{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sequence to Sequence Transformer model optimisation using TFMOT (Quantization Aware Training)\n",
    "\n",
    "An example notebook demonstrating how the TensorFlow Model Optimization Toolkit (TFMOT) can be used to optimise complex sequence-to-sequence Transformer models."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Background\n",
    "\n",
    "The sequence-to-sequence Transformer is the original Transformer architecture. The core idea behind the Transformer model is self-attention: the ability to attend to different positions of the input sequence to compute a representation of that sequence. The paper [\"Attention Is All You Need\"](https://arxiv.org/pdf/1706.03762.pdf) gives a deeper insight into the Transformer model and its self-attention mechanism.\n",
    "\n",
    "<img src=\"https://deepfrench.gitlab.io/deep-learning-project/resources/transformer.png\" alt=\"Sequence to sequence transformer\" width=\"1000\" align=\"center\" title=\"Transformer\">\n",
    "\n",
    "The image above is taken from [here](https://deepfrench.gitlab.io/deep-learning-project/)."
   ]
  },
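  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of self-attention, the next cell computes single-head scaled dot-product attention, `softmax(Q @ K.T / sqrt(d_k)) @ V`, in plain Python. It is not part of the tutorial's model, which uses `layers.MultiHeadAttention`. The function name `toy_scaled_dot_product_attention` and the toy input values are made up for this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def toy_scaled_dot_product_attention(q, k, v):\n",
    "    # q: n_q x d_k, k: n_k x d_k, v: n_k x d_v (plain lists of lists)\n",
    "    d_k = len(k[0])\n",
    "    outputs = []\n",
    "    for q_row in q:\n",
    "        # Similarity of this query with every key, scaled by sqrt(d_k)\n",
    "        scores = [sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in k]\n",
    "        # Numerically stable softmax over the keys gives the attention weights\n",
    "        top = max(scores)\n",
    "        exps = [math.exp(s - top) for s in scores]\n",
    "        total = sum(exps)\n",
    "        weights = [e / total for e in exps]\n",
    "        # Output = attention-weighted sum of the value vectors\n",
    "        outputs.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))])\n",
    "    return outputs\n",
    "\n",
    "# Self-attention: queries, keys and values all come from the same sequence\n",
    "toy_x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]\n",
    "print(toy_scaled_dot_product_attention(toy_x, toy_x, toy_x))"
   ]
  },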
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### In this notebook:\n",
    "\n",
    "* Train the Transformer model from the [Keras tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)\n",
    "* Rewrite the above model as a functional FP32 model\n",
    "* Perform Quantization Aware Training (QAT) on the FP32 model\n",
    "* Convert the QAT model to a TFLite model and evaluate it\n",
    "\n",
    "Note: This tutorial reuses some code and explanations from the original [Keras tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### TFMOT limitations\n",
    "- Subclassed models are not supported; only Sequential and functional model definitions are. (Pruning, Clustering & QAT)\n",
    "- Custom subclassed layers are not supported. (Clustering & QAT)\n",
    "    - Clustering only works with subclassed layers if the weight variables to be clustered are not nested within another layer (e.g. `MultiHeadAttention`).\n",
    "    - QAT works correctly if the subclassed layer performs only one operation.\n",
    "- Low-level TensorFlow operators such as `tf.linalg.matmul` are not supported. (QAT only)\n",
    "    - QAT expects every quantised layer to be a subclass of `tf.keras.layers.Layer`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pathlib\n",
    "import random\n",
    "import tempfile\n",
    "import zipfile\n",
    "import re\n",
    "import os\n",
    "import nltk\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras.layers import TextVectorization\n",
    "import tensorflow_model_optimization as tfmot\n",
    "from collections import defaultdict\n",
    "\n",
    "def reset_random_seeds():\n",
    "   os.environ['PYTHONHASHSEED']=str(2)\n",
    "   tf.random.set_seed(2)\n",
    "   np.random.seed(2)\n",
    "   random.seed(2)\n",
    "\n",
    "reset_random_seeds()\n",
    "\n",
    "print('TensorFlow version: {}'.format(tf.__version__))\n",
    "print('TFMOT version: {}'.format(tfmot.__version__))\n",
    "print(\"NLTK version: {}\".format(nltk.__version__))\n",
    "print(\"Numpy version: {}\".format(np.__version__))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Downloading the data\n",
    "\n",
    "The dataset used here is an English-to-Spanish translation dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_file = keras.utils.get_file(\n",
    "    fname=\"spa-eng.zip\",\n",
    "    origin=\"http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip\",\n",
    "    extract=True,\n",
    ")\n",
    "text_file = pathlib.Path(text_file).parent / \"spa-eng\" / \"spa.txt\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Parsing the data\n",
    "\n",
    "At this stage, a `[start]` token is prepended and an `[end]` token is appended to each target (Spanish) sentence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(text_file) as f:\n",
    "    lines = f.read().split(\"\\n\")[:-1]\n",
    "text_pairs = []\n",
    "for line in lines:\n",
    "    eng, spa = line.split(\"\\t\")\n",
    "    spa = \"[start] \" + spa + \" [end]\"\n",
    "    text_pairs.append((eng, spa))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for _ in range(5):\n",
    "    print(random.choice(text_pairs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split the dataset into training (~70%), validation (~15%) and test (~15%) sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random.shuffle(text_pairs)\n",
    "num_val_samples = int(0.15 * len(text_pairs))\n",
    "num_train_samples = len(text_pairs) - 2 * num_val_samples\n",
    "train_pairs = text_pairs[:num_train_samples]\n",
    "val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]\n",
    "test_pairs = text_pairs[num_train_samples + num_val_samples :]\n",
    "\n",
    "print(f\"{len(text_pairs)} total pairs\")\n",
    "print(f\"{len(train_pairs)} training pairs\")\n",
    "print(f\"{len(val_pairs)} validation pairs\")\n",
    "print(f\"{len(test_pairs)} test pairs\")"
   ]
  },
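  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split sizes follow directly from the arithmetic above: the validation and test sets each get `int(0.15 * n)` pairs and the training set gets the rest. The hypothetical helper `toy_split_sizes` below reproduces that calculation for any dataset size `n` without touching `text_pairs`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def toy_split_sizes(n, val_fraction=0.15):\n",
    "    # Same arithmetic as the split cell above\n",
    "    num_val = int(val_fraction * n)\n",
    "    num_train = n - 2 * num_val\n",
    "    num_test = n - num_train - num_val  # always equals num_val\n",
    "    return num_train, num_val, num_test\n",
    "\n",
    "print(toy_split_sizes(1000))  # (700, 150, 150)"
   ]
  },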
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Vectorizing the text data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Vectorization refers to the preprocessing step in which text features are mapped to integer sequences, where each integer is the index of a word in a vocabulary. For this, the [`tf.keras.layers.TextVectorization`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization) layer is used.\n",
    "\n",
    "In our case, vectorization of the English sequences differs slightly from that of the Spanish sequences:\n",
    "\n",
    "- For English sequences, the default standardization is used, which lowercases the text and strips all punctuation characters.\n",
    "- For Spanish sequences, a custom standardization is used, which lowercases the text and then strips every character other than spaces, the letters `a-z` and the characters `.`, `?`, `!`, `,`, `¿`, `[` and `]`. Keeping `[` and `]` preserves the `[start]` and `[end]` tokens. Note that accented letters such as `á` and `ñ` are stripped as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_size = 15000\n",
    "seq_len = 20\n",
    "batch_size = 64\n",
    "embed_dim = 256\n",
    "latent_dim = 2048\n",
    "num_heads = 8\n",
    "\n",
    "\n",
    "def custom_standardization(input_string):\n",
    "    lowercase = tf.strings.lower(input_string)\n",
    "    # Remove (replace with \"\") every character that is not one of:\n",
    "    # 1. A lowercase letter a-z\n",
    "    # 2. A space\n",
    "    # 3. One of these characters: \".\", \"?\", \"!\", \",\", \"¿\", \"[\", \"]\"\n",
    "    return tf.strings.regex_replace(lowercase, r'[^ a-z.?!,¿\\[\\]]', \"\")\n",
    "\n",
    "\n",
    "eng_vectorization = TextVectorization(\n",
    "    max_tokens=vocab_size, output_mode=\"int\", output_sequence_length=seq_len,\n",
    ")\n",
    "spa_vectorization = TextVectorization(\n",
    "    max_tokens=vocab_size,\n",
    "    output_mode=\"int\",\n",
    "    output_sequence_length=seq_len + 1,\n",
    "    standardize=custom_standardization,\n",
    ")\n",
    "\n",
    "train_eng_texts = [pair[0] for pair in train_pairs]\n",
    "train_spa_texts = [pair[1] for pair in train_pairs]\n",
    "eng_vectorization.adapt(train_eng_texts)\n",
    "spa_vectorization.adapt(train_spa_texts)"
   ]
  },
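  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what `custom_standardization` does to a Spanish sentence, the next cell applies the same character class with Python's `re` module instead of `tf.strings.regex_replace`. This is only an approximation for inspection: `str.lower` may treat non-ASCII letters differently from `tf.strings.lower`, which makes no difference here because every non-ASCII letter is stripped anyway. The helper name `toy_standardize` and the sample sentence are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Same character class as custom_standardization:\n",
    "# keep spaces, a-z and . ? ! , ¿ [ ]; drop everything else\n",
    "TOY_STRIP_PATTERN = r'[^ a-z.?!,¿\\[\\]]'\n",
    "\n",
    "def toy_standardize(text):\n",
    "    # Lowercase first, then remove every character outside the class\n",
    "    return re.sub(TOY_STRIP_PATTERN, '', text.lower())\n",
    "\n",
    "print(toy_standardize('[start] ¿Qué hora es? [end]'))  # [start] ¿qu hora es? [end]"
   ]
  },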
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At each training step, the model seeks to predict target words N+1 (and beyond) using the source (input) sentence and the target words 0 to N. For this reason, each element of the dataset is an (`inputs`, `targets`) pair:\n",
    "\n",
    "- `inputs`:\n",
    "\n",
    "    After vectorization, `inputs` is a dictionary with four entries:\n",
    "\n",
    "    * `encoder_inputs`: the vectorized English sentence.\n",
    "    * `decoder_inputs`: the vectorized Spanish (target) sentence without its last token, i.e. `target_sentence[:, :-1]`. This is the target sentence \"so far\", that is, words 0 to N used to predict word N+1 (and beyond).\n",
    "    * `encoder_masks`: a float mask for `encoder_inputs` (1.0 for real tokens, 0.0 for padding).\n",
    "    * `decoder_masks`: a float mask for `decoder_inputs` (1.0 for real tokens, 0.0 for padding).\n",
    "\n",
    "    Note that the two mask inputs are only required by the custom FP32 functional model. The original Keras model generates its own masks and therefore ignores these two inputs, so no extra action is needed for it. <br><br>\n",
    "    \n",
    "- `targets`:\n",
    "\n",
    "    `targets` is the target sentence offset by one token, i.e. `target_sentence[:, 1:]`. This is what the model tries to predict."
   ]
  },
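  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy example of the offset-by-one scheme, using plain Python lists instead of tensors. The token IDs below are hypothetical (in the notebook, 0 is the padding ID produced by `TextVectorization`), and the sequence has length 6 instead of `seq_len + 1`. The mask follows the same rule as `format_dataset`: 1.0 wherever `decoder_inputs` is non-zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical IDs for \"[start] ve. [end]\" padded with 0s to length 6\n",
    "toy_spa_ids = [2, 15, 3, 0, 0, 0]\n",
    "\n",
    "toy_decoder_inputs = toy_spa_ids[:-1]  # like spa[:, :-1]: words 0..N fed to the decoder\n",
    "toy_targets = toy_spa_ids[1:]          # like spa[:, 1:]: words 1..N+1 to be predicted\n",
    "toy_decoder_mask = [1.0 if t != 0 else 0.0 for t in toy_decoder_inputs]\n",
    "\n",
    "for inp_id, tgt_id, m in zip(toy_decoder_inputs, toy_targets, toy_decoder_mask):\n",
    "    print(f'input {inp_id:>2} -> target {tgt_id:>2} (mask {m})')"
   ]
  },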
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_dataset(eng, spa):\n",
    "    eng = eng_vectorization(eng)\n",
    "    spa = spa_vectorization(spa)\n",
    "\n",
    "    # Create input masks\n",
    "    encoder_masks=tf.cast(tf.not_equal(np.int64(0),eng),tf.float32)\n",
    "    decoder_masks=tf.cast(tf.not_equal(np.int64(0),spa[:, :-1]),tf.float32)\n",
    "    \n",
    "    return ({\"encoder_inputs\": eng, \"encoder_masks\": encoder_masks, \"decoder_inputs\": spa[:, :-1], \"decoder_masks\": decoder_masks}, spa[:, 1:])\n",
    "\n",
    "def make_dataset(pairs, batch_size=64):\n",
    "    eng_texts, spa_texts = zip(*pairs)\n",
    "    eng_texts = list(eng_texts)\n",
    "    spa_texts = list(spa_texts)\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((eng_texts, spa_texts))\n",
    "    dataset = dataset.batch(batch_size, drop_remainder=True)\n",
    "    dataset = dataset.map(format_dataset)\n",
    "    \n",
    "    # Cache first so that shuffling produces a new order every epoch\n",
    "    return dataset.cache().shuffle(2048).prefetch(16)\n",
    "\n",
    "\n",
    "train_ds = make_dataset(train_pairs, batch_size)\n",
    "val_ds = make_dataset(val_pairs, batch_size)\n",
    "test_ds = make_dataset(test_pairs, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for inputs, targets in train_ds.take(1):\n",
    "    print(f'inputs[\"encoder_inputs\"].shape: {inputs[\"encoder_inputs\"].shape}')\n",
    "    print(f'inputs[\"decoder_inputs\"].shape: {inputs[\"decoder_inputs\"].shape}')\n",
    "    print(f'inputs[\"encoder_masks\"].shape: {inputs[\"encoder_masks\"].shape}')\n",
    "    print(f'inputs[\"decoder_masks\"].shape: {inputs[\"decoder_masks\"].shape}')\n",
    "    print(f\"targets.shape: {targets.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Utility functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The BLEU score is commonly used to measure the quality of a translation. Here it is computed at corpus level with NLTK's `corpus_bleu`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
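    "def bleu_score(real_text, predicted_text):\n",
    "    '''Return the corpus-level BLEU score.\n",
    "\n",
    "    real_text: list (one entry per sentence) of lists of reference token lists\n",
    "    predicted_text: list of hypothesis token lists\n",
    "    '''\n",
    "    return nltk.translate.bleu_score.corpus_bleu(real_text, predicted_text)"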
   ]
  },
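  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BLEU is built on clipped (\"modified\") n-gram precision: each n-gram in the hypothesis only counts as correct up to the number of times it appears in a reference. The next cell sketches that single ingredient for one sentence in plain Python. `toy_ngrams` and `toy_clipped_precision` are made-up helpers; the full score computed by `corpus_bleu` additionally aggregates counts over the whole corpus, combines n = 1..4 with a geometric mean and applies a brevity penalty."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "def toy_ngrams(tokens, n):\n",
    "    # All contiguous n-grams of a token list, as tuples\n",
    "    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))\n",
    "\n",
    "def toy_clipped_precision(references, hypothesis, n=1):\n",
    "    hyp_counts = toy_ngrams(hypothesis, n)\n",
    "    # For each n-gram, the largest count found in any single reference\n",
    "    max_ref_counts = Counter()\n",
    "    for ref in references:\n",
    "        for gram, count in toy_ngrams(ref, n).items():\n",
    "            max_ref_counts[gram] = max(max_ref_counts[gram], count)\n",
    "    # Clip each hypothesis count by that reference maximum\n",
    "    clipped = sum(min(count, max_ref_counts[gram]) for gram, count in hyp_counts.items())\n",
    "    total = sum(hyp_counts.values())\n",
    "    return clipped / total if total else 0.0\n",
    "\n",
    "toy_refs = [['el', 'gato', 'esta', 'en', 'la', 'alfombra']]\n",
    "toy_hyp = ['el', 'el', 'el', 'gato']\n",
    "print(toy_clipped_precision(toy_refs, toy_hyp, n=1))  # 0.5"
   ]
  },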
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For decoding (i.e. translating a source sentence into a target sentence), we provide the vectorized source sentence as `encoder_inputs` and a vectorized `[start]` token (padded to the right sequence length) as `decoder_inputs`. We then repeatedly generate the next token, picking the most likely one with `argmax`, until we hit the `[end]` token or reach the maximum sentence length.\n",
    "\n",
    "Note that the custom FP32 functional model used in this notebook also takes `encoder_masks` and `decoder_masks` as inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_text_result(model, num_samples_to_eval=200, no_input_masks=False):\n",
    "    '''\n",
    "    Function to calculate the BLEU score on the test set\n",
    "\n",
    "    num_samples_to_eval: Total number of test sentences to consider\n",
    "                         during evaluation. To use the entire test set,\n",
    "                         set num_samples_to_eval to a negative value,\n",
    "                         e.g. -1\n",
    "    \n",
    "    no_input_masks: Set to True for the original Transformer model from\n",
    "                    the Keras example\n",
    "    '''\n",
    "\n",
    "    spa_vocab = spa_vectorization.get_vocabulary()\n",
    "    spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))\n",
    "    max_decoded_sentence_length = 20\n",
    "\n",
    "    def decode_sequence_func(input_sentence):\n",
    "\n",
    "        tokenized_input_sentence = eng_vectorization([input_sentence])\n",
    "        encoder_mask = tf.cast(tf.not_equal(np.int64(0),tokenized_input_sentence), tf.float32)\n",
    "\n",
    "        decoded_sentence = \"[start]\"\n",
    "        for i in range(max_decoded_sentence_length):\n",
    "            tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]\n",
    "            decoder_mask=tf.cast(tf.not_equal(np.int64(0),tokenized_target_sentence), tf.float32)\n",
    "            if no_input_masks:\n",
    "                predictions = model([tokenized_input_sentence, tokenized_target_sentence])\n",
    "            else:\n",
    "                predictions = model([tokenized_input_sentence, encoder_mask, tokenized_target_sentence,decoder_mask])\n",
    "            sampled_token_index = np.argmax(predictions[0, i, :])\n",
    "            sampled_token = spa_index_lookup[sampled_token_index]\n",
    "            decoded_sentence += \" \" + sampled_token\n",
    "\n",
    "            if sampled_token == \"[end]\":\n",
    "                break\n",
    "\n",
    "        return decoded_sentence\n",
    "\n",
    "\n",
    "    hypothesis = []\n",
    "    references = []\n",
    "    # A negative value means: evaluate the whole test set\n",
    "    samples = test_pairs if num_samples_to_eval < 0 else test_pairs[:num_samples_to_eval]\n",
    "    progbar = tf.keras.utils.Progbar(len(samples))\n",
    "\n",
    "    for step, (inp, target) in enumerate(samples):\n",
    "        translated = decode_sequence_func(inp)\n",
    "        # Apply the same standardization as custom_standardization\n",
    "        target = target.lower()\n",
    "        target = re.sub(r'[^ a-z.?!,¿\\[\\]]', \"\", target)\n",
    "        # Drop the leading [start] and trailing [end] tokens\n",
    "        hypothesis.append(translated.split()[1:-1])\n",
    "        references.append([target.split()[1:-1]])\n",
    "        progbar.update(step + 1)\n",
    "\n",
    "    print(\"Bleu Score: \" + str(bleu_score(references, hypothesis)))\n",
    "\n",
    "    # Print the first (up to) 10 actual and predicted Spanish translations as a sanity check\n",
    "    for i in range(min(10, len(hypothesis))):\n",
    "        print(references[i][0])\n",
    "        print(hypothesis[i])\n",
    "        print(\"-----------------------\\n\")\n"
   ]
  },
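  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The decoding loop above is greedy: at each step it keeps only the single most likely token. The next cell sketches the same loop with a made-up lookup table standing in for the model, so it runs without any trained weights. Unlike the real loop, which reads the model output at position `i`, this toy \"model\" only looks at the last decoded token. `TOY_NEXT`, `toy_next_token_scores` and `toy_greedy_decode` are all hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical stand-in for the model: last decoded token -> next token\n",
    "TOY_NEXT = {'[start]': 've.', 've.': '[end]'}\n",
    "\n",
    "def toy_next_token_scores(decoded_tokens, vocab):\n",
    "    # One score per vocabulary entry; the table's choice gets 1.0\n",
    "    wanted = TOY_NEXT.get(decoded_tokens[-1], '[end]')\n",
    "    return [1.0 if tok == wanted else 0.0 for tok in vocab]\n",
    "\n",
    "def toy_greedy_decode(vocab, max_len=20):\n",
    "    decoded = ['[start]']\n",
    "    for _ in range(max_len):\n",
    "        scores = toy_next_token_scores(decoded, vocab)\n",
    "        # Greedy step: index of the highest score, like np.argmax\n",
    "        best = max(range(len(scores)), key=scores.__getitem__)\n",
    "        decoded.append(vocab[best])\n",
    "        if vocab[best] == '[end]':\n",
    "            break\n",
    "    return ' '.join(decoded)\n",
    "\n",
    "toy_vocab = ['', '[UNK]', '[start]', '[end]', 've.']\n",
    "print(toy_greedy_decode(toy_vocab))  # [start] ve. [end]"
   ]
  },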
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
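    "Suggestion: when running inference on a TFLite file, make sure that the scale, zero point and data type of every input and output are correct. The helper functions below also assume that the TFLite inputs appear in the order encoder inputs, encoder masks, decoder inputs, decoder masks; check `interpreter.get_input_details()` if the results look wrong."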
   ]
  },
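  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For an int8 tensor, TFLite uses an affine mapping between real values `r` and quantized values `q`: `q = round(r / scale) + zero_point` and `r = scale * (q - zero_point)`. The next cell sketches both directions in plain Python. The scale and zero point are hypothetical values of the kind a [0, 1] mask input might get; read the real ones from `input_details[i]['quantization']` and `output_details[0]['quantization']`. Rounding rather than truncating matters here: a value such as 126.9999 would truncate to 126."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def toy_quantize(real_value, scale, zero_point, qmin=-128, qmax=127):\n",
    "    # q = round(r / scale) + zero_point, clamped to the int8 range\n",
    "    q = round(real_value / scale) + zero_point\n",
    "    return max(qmin, min(qmax, q))\n",
    "\n",
    "def toy_dequantize(q, scale, zero_point):\n",
    "    # r = scale * (q - zero_point)\n",
    "    return scale * (q - zero_point)\n",
    "\n",
    "# Hypothetical parameters mapping the real range [0, 1] onto int8 [-128, 127]\n",
    "toy_scale, toy_zero_point = 1.0 / 255.0, -128\n",
    "print(toy_quantize(0.0, toy_scale, toy_zero_point))  # -128\n",
    "print(toy_quantize(1.0, toy_scale, toy_zero_point))  # 127\n",
    "print(toy_dequantize(127, toy_scale, toy_zero_point))"
   ]
  },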
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_text_result_tflite(model_path, input_type = 'int8/32', output_type = 'int8', num_samples_to_eval = 200):\n",
    "    '''\n",
    "    Function to calculate BLEU score for a given tflite file on the test set\n",
    "\n",
    "    model_path: Path to the tflite file\n",
    "\n",
    "    input_type: Could be float32 or int8/32. If the inputs in tflite graph\n",
    "                are float32 set this value to 'float32' but if inputs are\n",
    "                int8 (mask inputs) and int32 (non-mask inputs) set this\n",
    "                value to 'int8/32'.\n",
    "\n",
    "    output_type: Could be float32 or int8. If the outputs in tflite graph\n",
    "                 are float32 set this value to 'float32' but if output\n",
    "                 are int8 set this value to 'int8'.\n",
    "                \n",
    "    num_samples_to_eval: Evaluating the entire test set takes a long time,\n",
    "                         therefore only the first 200 samples are\n",
    "                         evaluated by default. To evaluate the entire\n",
    "                         test set, set this to a negative value,\n",
    "                         e.g. -1\n",
    "    '''\n",
    "    assert(input_type in ['float32', 'int8/32']), \"input_type not supported\"\n",
    "    assert(output_type in ['float32', 'int8']), \"output_type not supported\"\n",
    "\n",
    "    print('Performing BLEU evaluation for tflite file at {}'.format(model_path))\n",
    "\n",
    "    spa_vocab = spa_vectorization.get_vocabulary()\n",
    "    spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))\n",
    "    max_decoded_sentence_length = seq_len\n",
    "\n",
    "    interpreter = tf.lite.Interpreter(model_path=model_path)\n",
    "\n",
    "    input_details = interpreter.get_input_details()\n",
    "    output_details = interpreter.get_output_details()\n",
    "\n",
    "    input_scale_1, input_zero_point_1 = input_details[0]['quantization']\n",
    "    input_scale_2, input_zero_point_2 = input_details[1]['quantization']\n",
    "    input_scale_3, input_zero_point_3 = input_details[2]['quantization']\n",
    "    input_scale_4, input_zero_point_4 = input_details[3]['quantization']\n",
    "    output_scale, output_zero_point = output_details[0]['quantization']\n",
    "\n",
    "    interpreter.allocate_tensors()\n",
    "\n",
    "    def decode_sequence_func(input_sentence):\n",
    "\n",
    "        input_1 = eng_vectorization([input_sentence])\n",
    "        input_2 = tf.cast(tf.not_equal(np.int64(0),input_1), tf.float32)\n",
    "        if input_type == 'int8/32':\n",
    "            input_2 = tf.round(input_2 / input_scale_2 + input_zero_point_2)\n",
    "\n",
    "        decoded_sentence = \"[start]\"\n",
    "\n",
    "        for i in range(max_decoded_sentence_length):\n",
    "            input_3 = spa_vectorization([decoded_sentence])[:, :-1]\n",
    "            input_4=tf.cast(tf.not_equal(np.int64(0),input_3), tf.float32)\n",
    "\n",
    "            # Set input tensor\n",
    "            interpreter.set_tensor(input_details[0]['index'], tf.cast(input_1, input_details[0]['dtype']))\n",
    "\n",
    "            # Set input tensor\n",
    "            interpreter.set_tensor(input_details[1]['index'], tf.cast(input_2, input_details[1]['dtype']))\n",
    "\n",
    "            # Set input tensor\n",
    "            interpreter.set_tensor(input_details[2]['index'], tf.cast(input_3, input_details[2]['dtype']))\n",
    "\n",
    "            # Set input tensor\n",
    "            if input_type == 'int8/32':\n",
    "                input_4 = tf.round(input_4 / input_scale_4 + input_zero_point_4)\n",
    "            interpreter.set_tensor(input_details[3]['index'], tf.cast(input_4, input_details[3]['dtype']))\n",
    "\n",
    "            interpreter.invoke()\n",
    "            \n",
    "            # Get output tensor\n",
    "            output_data = interpreter.get_tensor(output_details[0]['index'])\n",
    "            predictions = output_data.astype(np.float32)\n",
    "            if output_type == 'int8':\n",
    "                predictions = output_scale * (predictions - output_zero_point)\n",
    "            \n",
    "            sampled_token_index = np.argmax(predictions[0, i, :])\n",
    "            sampled_token = spa_index_lookup[sampled_token_index]\n",
    "            decoded_sentence += \" \" + sampled_token\n",
    "\n",
    "            if sampled_token == \"[end]\":\n",
    "                break\n",
    "\n",
    "        return decoded_sentence\n",
    "\n",
    "\n",
    "    hypothesis = []\n",
    "    references = []\n",
    "    # A negative value means: evaluate the whole test set\n",
    "    samples = test_pairs if num_samples_to_eval < 0 else test_pairs[:num_samples_to_eval]\n",
    "    progbar = tf.keras.utils.Progbar(len(samples))\n",
    "\n",
    "    for step, (inp, target) in enumerate(samples):\n",
    "        translated = decode_sequence_func(inp)\n",
    "        target = target.lower()\n",
    "        target = re.sub(r'[^ a-z.?!,¿\\[\\]]', \"\", target)\n",
    "        hypothesis.append(translated.split()[1:-1])\n",
    "        references.append([target.split()[1:-1]])\n",
    "        progbar.update(step + 1)\n",
    "\n",
    "    print(\"Bleu Score: \" + str(bleu_score(references, hypothesis)))\n",
    "    for i in range(min(10, len(hypothesis))):\n",
    "        print(references[i][0])\n",
    "        print(hypothesis[i])\n",
    "        print(\"-----------------------\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_tflite_accuracy(model_path, input_type = 'int8/32', output_type = 'int8', num_samples_to_eval = 200):\n",
    "    '''\n",
    "    Function to calculate accuracy of a given tflite file on the test set\n",
    "\n",
    "    model_path: Path to the tflite file\n",
    "\n",
    "    input_type: Could be float32 or int8/32. If the inputs in tflite graph\n",
    "                are float32 set this value to 'float32' but if inputs are\n",
    "                int8 (mask inputs) and int32 (non-mask inputs) set this\n",
    "                value to 'int8/32'.\n",
    "\n",
    "    output_type: Could be float32 or int8. If the outputs in tflite graph\n",
    "                 are float32 set this value to 'float32' but if output\n",
    "                 are int8 set this value to 'int8'.\n",
    "                \n",
    "    num_samples_to_eval: Evaluating the entire test set takes a long time,\n",
    "                         therefore only the first 200 samples are\n",
    "                         evaluated by default. To evaluate the entire\n",
    "                         test set, set this to a negative value,\n",
    "                         e.g. -1\n",
    "    '''\n",
    "    assert(input_type in ['float32', 'int8/32']), \"input_type not supported\"\n",
    "    assert(output_type in ['float32', 'int8']), \"output_type not supported\"\n",
    "\n",
    "    print('Performing accuracy evaluation for tflite file at {}'.format(model_path))\n",
    "\n",
    "    interpreter = tf.lite.Interpreter(model_path=model_path)\n",
    "\n",
    "    input_details = interpreter.get_input_details()\n",
    "    output_details = interpreter.get_output_details()\n",
    "\n",
    "    input_scale_1, input_zero_point_1 = input_details[0]['quantization']\n",
    "    input_scale_2, input_zero_point_2 = input_details[1]['quantization']\n",
    "    input_scale_3, input_zero_point_3 = input_details[2]['quantization']\n",
    "    input_scale_4, input_zero_point_4 = input_details[3]['quantization']\n",
    "    output_scale, output_zero_point = output_details[0]['quantization']\n",
    "    interpreter.allocate_tensors()\n",
    "\n",
    "    test_ds_tflite = make_dataset(test_pairs, 1)\n",
    "    accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')\n",
    "    total_steps = len(test_ds_tflite) if num_samples_to_eval < 0 else num_samples_to_eval\n",
    "    progbar = tf.keras.utils.Progbar(total_steps, stateful_metrics=['accuracy'])\n",
    "\n",
    "    for step, (inputs, target) in enumerate(test_ds_tflite):\n",
    "\n",
    "        # Set input tensor: encoder token IDs\n",
    "        input_1 = inputs['encoder_inputs']\n",
    "        interpreter.set_tensor(input_details[0]['index'], tf.cast(input_1, input_details[0]['dtype']))\n",
    "\n",
    "        # Set input tensor: encoder mask (quantized if required)\n",
    "        input_2 = inputs['encoder_masks']\n",
    "        if input_type == 'int8/32':\n",
    "            input_2 = tf.cast(input_2, tf.float32)\n",
    "            input_2 = tf.round(input_2 / input_scale_2 + input_zero_point_2)\n",
    "        interpreter.set_tensor(input_details[1]['index'], tf.cast(input_2, input_details[1]['dtype']))\n",
    "\n",
    "        # Set input tensor: decoder token IDs\n",
    "        input_3 = inputs['decoder_inputs']\n",
    "        interpreter.set_tensor(input_details[2]['index'], tf.cast(input_3, input_details[2]['dtype']))\n",
    "\n",
    "        # Set input tensor: decoder mask (quantized if required)\n",
    "        input_4 = inputs['decoder_masks']\n",
    "        if input_type == 'int8/32':\n",
    "            input_4 = tf.cast(input_4, tf.float32)\n",
    "            input_4 = tf.round(input_4 / input_scale_4 + input_zero_point_4)\n",
    "        interpreter.set_tensor(input_details[3]['index'], tf.cast(input_4, input_details[3]['dtype']))\n",
    "        interpreter.invoke()\n",
    "\n",
    "        # Get output tensor and dequantize it if required\n",
    "        output_data = interpreter.get_tensor(output_details[0]['index'])\n",
    "        output_data = output_data.astype(np.float32)\n",
    "        if output_type == 'int8':\n",
    "            output_data = output_scale * (output_data - output_zero_point)\n",
    "\n",
    "        # Update accuracy, ignoring padded positions via the decoder mask\n",
    "        accuracy.update_state(target, output_data, inputs['decoder_masks'])\n",
    "        progbar.update(step + 1, values=[('accuracy', accuracy.result().numpy())])\n",
    "\n",
    "        # Stop once num_samples_to_eval samples have been evaluated\n",
    "        if step + 1 == num_samples_to_eval:\n",
    "            break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following helper returns the size of a TFLite file after zip (DEFLATE) compression, in kilobytes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_gzipped_model_size(file):\n",
    "  '''Returns the size of a tflite file after zip (DEFLATE) compression, in kilobytes'''\n",
    "\n",
    "  fd, zipped_file = tempfile.mkstemp('.zip')\n",
    "  os.close(fd)  # only the path is needed; ZipFile reopens the file\n",
    "  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n",
    "    f.write(file)\n",
    "\n",
    "  return os.path.getsize(zipped_file)/1000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6. Functions related to training the model\n",
    "\n",
    "The loss function is a masked sparse categorical cross-entropy loss, i.e. `tf.keras.losses.SparseCategoricalCrossentropy` with the padding mask passed as `sample_weight`.\n",
    "For this loss to be correct, the masks would need to be propagated through the model layers down to the loss function. The custom FP32 model cannot do this, so a custom training loop is used that computes the mask from `decoder_inputs` and passes it to the loss explicitly."
   ]
  },
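  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The masked loss can be sketched in plain Python: compute the cross-entropy `-log p[target]` at every position, weight it by the mask (1.0 for real tokens, 0.0 for padding) and average over the unmasked positions. `toy_masked_sparse_ce` and its inputs are made up, and this normalisation is one common convention: under its default reduction, Keras's `SparseCategoricalCrossentropy` loss with `sample_weight` may instead divide by the total number of positions, so absolute values can differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def toy_masked_sparse_ce(targets, probs, mask):\n",
    "    # targets: one token ID per position\n",
    "    # probs: one probability list (over the vocabulary) per position\n",
    "    # mask: 1.0 for real tokens, 0.0 for padding\n",
    "    losses = [-math.log(p[t]) for t, p in zip(targets, probs)]\n",
    "    weighted_sum = sum(l * m for l, m in zip(losses, mask))\n",
    "    # Average over the unmasked positions only\n",
    "    return weighted_sum / sum(mask)\n",
    "\n",
    "toy_ce_targets = [1, 2, 0]\n",
    "toy_ce_probs = [[0.1, 0.8, 0.1], [0.2, 0.2, 0.6], [0.9, 0.05, 0.05]]\n",
    "toy_ce_mask = [1.0, 1.0, 0.0]  # the last position is padding and is ignored\n",
    "print(toy_masked_sparse_ce(toy_ce_targets, toy_ce_probs, toy_ce_mask))"
   ]
  },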
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 11\n",
    "\n",
    "def evaluate(model_to_eval, training=False):\n",
    "\n",
    "    val_loss = tf.keras.metrics.SparseCategoricalCrossentropy()\n",
    "    val_acc = tf.keras.metrics.SparseCategoricalAccuracy()\n",
    "\n",
    "    @tf.function\n",
    "    def eval_step(inp, y_true):\n",
    "        preds = model_to_eval(inp, training=training)\n",
    "        # masked loss\n",
    "        val_loss.update_state(y_true, preds,tf.cast(tf.not_equal(np.int64(0),inp['decoder_inputs']),tf.float32))  \n",
    "        # masked accuracy\n",
    "        val_acc.update_state(y_true, preds,tf.cast(tf.not_equal(np.int64(0),inp['decoder_inputs']),tf.float32))  \n",
    "\n",
    "    for step, (inp, y_true) in enumerate(val_ds):\n",
    "        eval_step(inp, y_true)\n",
    "\n",
    "    return {'loss': val_loss.result().numpy(), 'accuracy': val_acc.result().numpy()}\n",
    "\n",
    "\n",
    "def train(model_to_train, save_best_weights =True, model_type='original', lr=1e-3, epochs = epochs):\n",
    "\n",
    "    if model_type == 'original':\n",
    "        ckpt_path = './eng_spa_transformer_qat_tutorial_model.h5'\n",
    "    elif model_type == 'fp32':\n",
    "        ckpt_path = './eng_spa_transformer_qat_tutorial_fp32_model.h5'\n",
    "    elif model_type == 'qat':\n",
    "        ckpt_path = './eng_spa_transformer_qat_tutorial_qat_model.h5'\n",
    "    else:\n",
    "        print('Please select the correct model type!!')\n",
    "        return None\n",
    "    \n",
    "    print('Training (save_best_weights={}, model_type={})'.format(save_best_weights, model_type))\n",
    "\n",
    "    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()\n",
    "    optimiser = tf.keras.optimizers.Adam(learning_rate=lr)\n",
    "    train_acc = tf.keras.metrics.SparseCategoricalAccuracy()\n",
    "    model_to_train.optimizer = optimiser\n",
    "\n",
    "    @tf.function\n",
    "    def train_step(inp, y_true):\n",
    "        mask =tf.cast(tf.not_equal(np.int64(0),inp['decoder_inputs']),tf.float32)\n",
    "        preds=None\n",
    "        loss=None\n",
    "        \n",
    "        with tf.GradientTape() as tape:\n",
    "            preds = model_to_train(inp, training=True)\n",
    "            # Masked loss: padded positions get zero weight\n",
    "            loss = loss_fn(y_true, preds, mask)\n",
    "\n",
    "        # Compute and apply gradients outside the tape context\n",
    "        grads = tape.gradient(loss, model_to_train.trainable_weights)\n",
    "        optimiser.apply_gradients(zip(grads, model_to_train.trainable_weights))\n",
    "\n",
    "        # Masked accuracy\n",
    "        train_acc.update_state(y_true, preds, mask)\n",
    "        return loss\n",
    "\n",
    "    max_val = float('-inf')\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        print('Epoch {}/{}'.format(epoch + 1, epochs), flush=True)\n",
    "        # Train\n",
    "        progbar = tf.keras.utils.Progbar(len(train_ds), interval=.5,\n",
    "                                        stateful_metrics=['acc'])        \n",
    "\n",
    "        for step, (inp, y_true) in enumerate(train_ds):\n",
    "            loss = train_step(inp, y_true)\n",
    "            progbar.update(step + 1, values=[('loss', loss),\n",
    "                                             ('acc', train_acc.result())])\n",
    "\n",
    "        # Evaluate\n",
    "        val_results = evaluate(model_to_train)\n",
    "\n",
    "        validation_accuracy = val_results['accuracy']\n",
    "        print('Validation accuracy: {}'.format(validation_accuracy))\n",
    "\n",
    "        if save_best_weights and validation_accuracy > max_val:\n",
    "            \n",
    "            print('Best validation accuracy so far, saving weights')\n",
    "            model_to_train.save_weights(ckpt_path)\n",
    "            max_val = validation_accuracy\n",
    "\n",
    "        train_acc.reset_states()        \n",
    "\n",
    "    if not save_best_weights:\n",
    "        model_to_train.save_weights(ckpt_path)\n",
    "    # Load weights\n",
    "    model_to_train.load_weights(ckpt_path)"
   ]
  },
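  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training step above relies on `loss_fn(y_true, preds, mask)` and `train_acc.update_state(y_true, preds, mask)`, both defined earlier in the notebook, to ignore padded positions. The NumPy sketch below illustrates what that masking does. It assumes the mask is 1 for real tokens and 0 for padding, and that the loss is token-level sparse categorical cross-entropy averaged over unmasked positions. The helper names `masked_sparse_ce` and `masked_accuracy` are hypothetical, and the notebook's own functions may differ in detail.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def masked_sparse_ce(y_true, probs, mask, eps=1e-7):\n",
    "    # y_true: (batch, seq) token ids; probs: (batch, seq, vocab) softmax outputs\n",
    "    # mask: (batch, seq), 1 for real tokens and 0 for padding\n",
    "    batch_idx, seq_idx = np.indices(y_true.shape)\n",
    "    picked = probs[batch_idx, seq_idx, y_true]  # probability of the true token\n",
    "    per_token = -np.log(np.clip(picked, eps, 1.0))\n",
    "    mask = mask.astype(np.float32)\n",
    "    # Average only over unmasked positions\n",
    "    return float((per_token * mask).sum() / mask.sum())\n",
    "\n",
    "def masked_accuracy(y_true, probs, mask):\n",
    "    # A prediction counts only if its position is not padding\n",
    "    correct = (probs.argmax(axis=-1) == y_true).astype(np.float32)\n",
    "    mask = mask.astype(np.float32)\n",
    "    return float((correct * mask).sum() / mask.sum())\n",
    "\n",
    "# Toy batch: one sequence of 3 positions over a 2-token vocabulary; the last position is padding\n",
    "y_true = np.array([[1, 0, 0]])\n",
    "probs = np.array([[[0.2, 0.8], [0.9, 0.1], [0.1, 0.9]]])\n",
    "mask = np.array([[1, 1, 0]])\n",
    "print(masked_accuracy(y_true, probs, mask))  # 1.0: the wrong prediction at the padded position is ignored\n",
    "print(masked_sparse_ce(y_true, probs, mask))\n",
    "```"
   ]
  },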
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7. Build the original Transformer Keras model from the [Keras tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(a) Define the custom layers for the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerEncoder(layers.Layer):\n",
    "    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):\n",
    "        super(TransformerEncoder, self).__init__(**kwargs)\n",
    "        self.embed_dim = embed_dim\n",
    "        self.dense_dim = dense_dim\n",
    "        self.num_heads = num_heads\n",
    "        self.attention = layers.MultiHeadAttention(\n",
    "            num_heads=num_heads, key_dim=embed_dim\n",
    "        )\n",
    "        self.dense_proj = keras.Sequential(\n",
    "            [layers.Dense(dense_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
    "        )\n",
    "        self.layernorm_1 = layers.LayerNormalization()\n",
    "        self.layernorm_2 = layers.LayerNormalization()\n",
    "        self.supports_masking = True\n",
    "\n",
    "    def call(self, inputs, mask=None):\n",
    "        # Without a mask, attention covers every position\n",
    "        padding_mask = None\n",
    "        if mask is not None:\n",
    "            padding_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype=\"int32\")\n",
    "        attention_output = self.attention(\n",
    "            query=inputs, value=inputs, key=inputs, attention_mask=padding_mask\n",
    "        )\n",
    "        proj_input = self.layernorm_1(inputs + attention_output)\n",
    "        proj_output = self.dense_proj(proj_input)\n",
    "        return self.layernorm_2(proj_input + proj_output)\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'embed_dim': self.embed_dim,\n",
    "                       'dense_dim': self.dense_dim,\n",
    "                       'num_heads': self.num_heads})\n",
    "        return config\n",
    "\n",
    "\n",
    "class PositionalEmbedding(layers.Layer):\n",
    "    def __init__(self, seq_len, vocab_size, embed_dim, **kwargs):\n",
    "        super(PositionalEmbedding, self).__init__(**kwargs)\n",
    "        self.token_embeddings = layers.Embedding(\n",
    "            input_dim=vocab_size, output_dim=embed_dim\n",
    "        )\n",
    "        self.position_embeddings = layers.Embedding(\n",
    "            input_dim=seq_len, output_dim=embed_dim\n",
    "        )\n",
    "        self.seq_len = seq_len\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embed_dim = embed_dim\n",
    "\n",
    "    def call(self, inputs):\n",
    "        length = tf.shape(inputs)[-1]\n",
    "        positions = tf.range(start=0, limit=length, delta=1)\n",
    "        embedded_tokens = self.token_embeddings(inputs)\n",
    "        embedded_positions = self.position_embeddings(positions)\n",
    "        return embedded_tokens + embedded_positions\n",
    "\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        return tf.math.not_equal(inputs, 0)\n",
    "\n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'embed_dim': self.embed_dim,\n",
    "                       'vocab_size': self.vocab_size,\n",
    "                       'seq_len': self.seq_len})\n",
    "        return config\n",
    "\n",
    "\n",
    "class TransformerDecoder(layers.Layer):\n",
    "    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):\n",
    "        super(TransformerDecoder, self).__init__(**kwargs)\n",
    "        self.embed_dim = embed_dim\n",
    "        self.latent_dim = latent_dim\n",
    "        self.num_heads = num_heads\n",
    "        self.attention_1 = layers.MultiHeadAttention(\n",
    "            num_heads=num_heads, key_dim=embed_dim\n",
    "        )\n",
    "        self.attention_2 = layers.MultiHeadAttention(\n",
    "            num_heads=num_heads, key_dim=embed_dim\n",
    "        )\n",
    "        self.dense_proj = keras.Sequential(\n",
    "            [layers.Dense(latent_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
    "        )\n",
    "        self.layernorm_1 = layers.LayerNormalization()\n",
    "        self.layernorm_2 = layers.LayerNormalization()\n",
    "        self.layernorm_3 = layers.LayerNormalization()\n",
    "        self.supports_masking = True\n",
    "\n",
    "    def call(self, inputs, encoder_outputs, mask=None):\n",
    "        causal_mask = self.get_causal_attention_mask(inputs)\n",
    "        # Without a mask, cross-attention covers every encoder position\n",
    "        padding_mask = None\n",
    "        if mask is not None:\n",
    "            padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype=\"int32\")\n",
    "            padding_mask = tf.minimum(padding_mask, causal_mask)\n",
    "\n",
    "        attention_output_1 = self.attention_1(\n",
    "            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask\n",
    "        )\n",
    "        out_1 = self.layernorm_1(inputs + attention_output_1)\n",
    "\n",
    "        attention_output_2 = self.attention_2(\n",
    "            query=out_1,\n",
    "            value=encoder_outputs,\n",
    "            key=encoder_outputs,\n",
    "            attention_mask=padding_mask,\n",
    "        )\n",
    "        out_2 = self.layernorm_2(out_1 + attention_output_2)\n",
    "\n",
    "        proj_output = self.dense_proj(out_2)\n",
    "        return self.layernorm_3(out_2 + proj_output)\n",
    "\n",
    "    def get_causal_attention_mask(self, inputs):\n",
    "        input_shape = tf.shape(inputs)\n",
    "        batch_size, seq_len = input_shape[0], input_shape[1]\n",
    "        i = tf.range(seq_len)[:, tf.newaxis]\n",
    "        j = tf.range(seq_len)\n",
    "        mask = tf.cast(i >= j, dtype=\"int32\")\n",
    "        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))\n",
    "        mult = tf.concat(\n",
    "            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],\n",
    "            axis=0,\n",
    "        )\n",
    "        return tf.tile(mask, mult)\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'embed_dim': self.embed_dim,\n",
    "                       'latent_dim': self.latent_dim,\n",
    "                       'num_heads': self.num_heads})\n",
    "        return config"
   ]
  },
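  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TransformerDecoder` combines two masks. First, `get_causal_attention_mask` builds a lower-triangular matrix so that position `i` can only attend to positions `j <= i`. Then the padding mask is merged into it with `tf.minimum`. The NumPy sketch below reproduces that arithmetic for a toy sequence so the resulting mask can be inspected. As in `PositionalEmbedding.compute_mask`, it assumes token id 0 is padding. The function names are hypothetical and not part of the model.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def causal_attention_mask(batch_size, seq_len):\n",
    "    # Same logic as get_causal_attention_mask: query i may attend to key j only if i >= j\n",
    "    i = np.arange(seq_len)[:, np.newaxis]\n",
    "    j = np.arange(seq_len)\n",
    "    mask = (i >= j).astype(np.int32)                      # (seq_len, seq_len), lower triangular\n",
    "    return np.tile(mask[np.newaxis], (batch_size, 1, 1))  # (batch, seq_len, seq_len)\n",
    "\n",
    "def combined_decoder_mask(token_ids):\n",
    "    # Padding mask (1 = real token, 0 = padding id 0), broadcast over query positions\n",
    "    # and merged with the causal mask by element-wise minimum, as in TransformerDecoder.call\n",
    "    batch_size, seq_len = token_ids.shape\n",
    "    padding = (token_ids != 0).astype(np.int32)[:, np.newaxis, :]  # (batch, 1, seq_len)\n",
    "    return np.minimum(padding, causal_attention_mask(batch_size, seq_len))\n",
    "\n",
    "tokens = np.array([[5, 7, 0]])  # one sequence whose last position is padding\n",
    "print(combined_decoder_mask(tokens)[0])\n",
    "# [[1 0 0]\n",
    "#  [1 1 0]\n",
    "#  [1 1 0]]\n",
    "```"
   ]
  },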
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(b) Build the end-to-end model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_encoder_decoder_model():\n",
    "    encoder_inputs = keras.Input(shape=(20,), dtype=\"int64\", name=\"encoder_inputs\")\n",
    "    x = PositionalEmbedding(seq_len, vocab_size, embed_dim)(encoder_inputs)\n",
    "    encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)\n",
    "    decoder_inputs = keras.Input(shape=(20,), dtype=\"int64\", name=\"decoder_inputs\")\n",
    "    encoded_seq_inputs = encoder_outputs\n",
    "    x = PositionalEmbedding(seq_len, vocab_size, embed_dim)(decoder_inputs)\n",
    "    x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)\n",
    "    x = layers.Dropout(0.5)(x)\n",
    "    decoder_outputs = layers.Dense(vocab_size, activation=\"softmax\")(x)\n",
    "    \n",
    "    transformer = keras.Model(\n",
    "        [encoder_inputs, decoder_inputs], decoder_outputs, name=\"transformer\"\n",
    "    )\n",
    "\n",
    "    return transformer\n",
    "\n",
    "transformer = get_encoder_decoder_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(c) Train the original Transformer model from the Keras example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer.summary()\n",
    "train(transformer, model_type='original')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(d) Evaluate performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get BLEU score on test set for original transformer model\n",
    "get_text_result(transformer, no_input_masks=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get accuracy on the test set for the original transformer model from the Keras example\n",
    "evaluate(transformer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8. Create an FP32 functional model for the Transformer\n",
    "\n",
    "For QAT, custom Keras layers must be defined for all of the low-level TensorFlow operators used by the model, and each layer must contain only a single operation.\n",
    "\n",
    "None of these layers has prunable weights, so we first create a base prunable layer class, whose `get_prunable_weights` returns an empty list, and extend it instead of `tf.keras.layers.Layer`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(a) Create a base prunable layer class "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PrunableLayer(tf.keras.layers.Layer, tfmot.sparsity.keras.PrunableLayer):\n",
    "    def get_prunable_weights(self): return []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(b) Define low-level TensorFlow operations as Keras subclassed layers\n",
    "\n",
    "Note that some of these layers have trainable weights defined using the `add_weight` method. These weights will not be pruned or clustered."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Tanh(PrunableLayer):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def call(self, x):\n",
    "        return tf.math.tanh(x)\n",
    "\n",
    "\n",
    "class Relu(PrunableLayer):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        \n",
    "    def call(self, x):\n",
    "        return tf.maximum(0., x)\n",
    "\n",
    "    \n",
    "class MatMul(PrunableLayer):\n",
    "    def __init__(self, transpose_b=False, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.transpose_b = transpose_b       \n",
    "    \n",
    "    def call(self, inputs):\n",
    "        return tf.linalg.matmul(*inputs, transpose_b=self.transpose_b)\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'transpose_b': self.transpose_b})\n",
    "        return config\n",
    "\n",
    "\n",
    "class Multiply(PrunableLayer):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        \n",
    "    def call(self, inputs):\n",
    "        return tf.multiply(*inputs)\n",
    "\n",
    "\n",
    "# Calling Multiply with a scalar input will lead to an error.\n",
    "# Use the following ScalarMultiply class instead.\n",
    "class ScalarMultiply(PrunableLayer):\n",
    "    def __init__(self, scalar, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.scalar = scalar        \n",
    "        \n",
    "    def call(self, x):\n",
    "        return tf.math.multiply(x, self.scalar)\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'scalar': self.scalar})\n",
    "        return config\n",
    "\n",
    "\n",
    "class Add(PrunableLayer):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "       \n",
    "    def call(self, inputs):\n",
    "        return tf.math.add(*inputs)\n",
    "\n",
    "\n",
    "# Calling Add with a scalar input will lead to an error.\n",
    "# Use the following ScalarAdd class instead.\n",
    "class ScalarAdd(PrunableLayer):\n",
    "    def __init__(self, scalar, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.scalar = scalar   \n",
    "    \n",
    "    def call(self, x):\n",
    "        return tf.math.add(x, self.scalar)\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'scalar': self.scalar})\n",
    "        return config\n",
    "\n",
    "\n",
    "class Slice(PrunableLayer):\n",
    "    def __init__(self, seq_idx, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.seq_idx = seq_idx      \n",
    "    \n",
    "    def call(self, x):\n",
    "        return x[:, self.seq_idx, ...]\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'seq_idx': self.seq_idx})\n",
    "        return config\n",
    "\n",
    "\n",
    "class Mean(PrunableLayer):\n",
    "    def __init__(self, axes=None, keepdims=True, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.axes=axes\n",
    "        self.keepdims = keepdims      \n",
    "    \n",
    "    def call(self, x):\n",
    "        return tf.math.reduce_mean(x, axis=self.axes, keepdims=self.keepdims)\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'axes': self.axes,\n",
    "                       'keepdims': self.keepdims})\n",
    "        return config\n",
    "\n",
    "\n",
    "class Subtract(PrunableLayer):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)       \n",
    "        \n",
    "    def call(self, inputs):\n",
    "        return tf.math.subtract(*inputs)\n",
    "\n",
    "\n",
    "class ScalarSubtract(PrunableLayer):\n",
    "    def __init__(self, scalar, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.scalar = scalar   \n",
    "    \n",
    "    def call(self, x):\n",
    "        # Note the operand order: this computes scalar - x\n",
    "        return tf.math.subtract(self.scalar, x)\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'scalar': self.scalar})\n",
    "        return config\n",
    "\n",
    "\n",
    "class SquaredDiffrence(PrunableLayer):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)       \n",
    "        \n",
    "    def call(self,inputs):\n",
    "        return tf.math.squared_difference(*inputs)\n",
    "\n",
    "\n",
    "class StopGradient(PrunableLayer):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        \n",
    "    def call(self, x):\n",
    "        return tf.stop_gradient(x)\n",
    "\n",
    "\n",
    "class RSqrt(PrunableLayer):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "    \n",
    "    def call(self, x):\n",
    "        return tf.math.rsqrt(x)\n",
    "\n",
    "\n",
    "class Clip(PrunableLayer):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "       \n",
    "    def call(self, x):\n",
    "        return tf.clip_by_value(x, 0.001, 255.0)\n",
    "\n",
    "\n",
    "class BroadcastToken(PrunableLayer):\n",
    "    \"\"\"Layer to broadcast the class token\"\"\"\n",
    "    def __init__(self, embedding_dim, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.embedding_dim = embedding_dim\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        self.w = self.add_weight(shape=(1, 1, self.embedding_dim), initializer='zeros', \n",
    "                                 trainable=True, name='token')\n",
    "        super().build(input_shape)\n",
    "\n",
    "    def call(self, x):\n",
    "        batch_size = tf.shape(x)[0]\n",
    "        return tf.broadcast_to(self.w, [batch_size, 1, self.embedding_dim])\n",
    "\n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'embedding_dim': self.embedding_dim})\n",
    "        return config\n",
    "\n",
    "\n",
    "class AddPositionalEmbedding(PrunableLayer):\n",
    "    \"\"\"Layer to add positional embeddings to the tokens\"\"\"\n",
    "    def __init__(self, seq_len, embedding_dim, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.embedding_dim = embedding_dim\n",
    "        self.seq_len = seq_len\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        self.w = self.add_weight(shape=(self.seq_len, self.embedding_dim), initializer= 'uniform',\n",
    "                                 trainable=True, name='pos_emb')\n",
    "        super().build(input_shape)\n",
    "\n",
    "    def call(self, x):\n",
    "        return x + self.w\n",
    "\n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'embedding_dim': self.embedding_dim, 'seq_len': self.seq_len})\n",
    "        return config\n",
    "\n",
    "\n",
    "class AddTokenEmbedding(PrunableLayer): \n",
    "    \"\"\"Layer to add token embeddings to the tokens\"\"\"\n",
    "    def __init__(self, vocab_size, embedding_dim, train = True, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.embedding_dim = embedding_dim\n",
    "        self.vocab_size = vocab_size\n",
    "        self.train = train\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        self.w = self.add_weight(shape=(self.vocab_size, self.embedding_dim), initializer= 'uniform',\n",
    "                                 trainable=self.train, name='token_emb')\n",
    "        super().build(input_shape)\n",
    "\n",
    "    def call(self, x):\n",
    "        return tf.gather(self.w,x)\n",
    "\n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'embedding_dim': self.embedding_dim, 'vocab_size': self.vocab_size, 'train': self.train})\n",
    "        return config\n",
    "\n",
    "    def compute_output_shape(self, input_shape):\n",
    "        # (batch, seq_len) token ids -> (batch, seq_len, embedding_dim) vectors\n",
    "        return tuple(input_shape) + (self.embedding_dim,)\n",
    "\n",
    "\n",
    "class Scale(PrunableLayer):\n",
    "    \"\"\"Multiply with gamma (LayerNorm)\"\"\"\n",
    "    def __init__(self, axes, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.axes = axes        \n",
    "        \n",
    "    def build(self, input_shape):\n",
    "        param_shape = [input_shape[dim] for dim in self.axes]\n",
    "        self.w = self.add_weight(name='gamma', shape=param_shape,\n",
    "                                 trainable=True, initializer='ones')\n",
    "        super().build(input_shape)\n",
    "        \n",
    "    def call(self, x):\n",
    "        return tf.multiply(x, self.w)\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'axes': self.axes})\n",
    "        return config\n",
    "\n",
    "    \n",
    "class Centre(PrunableLayer):\n",
    "    \"\"\"Add beta (LayerNorm)\"\"\"\n",
    "    def __init__(self, axes, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.axes = axes        \n",
    "        \n",
    "    def build(self, input_shape):\n",
    "        param_shape = [input_shape[dim] for dim in self.axes]\n",
    "        self.w = self.add_weight(name='beta', shape=param_shape,\n",
    "                                 trainable=True, initializer='zeros')\n",
    "        super().build(input_shape)\n",
    "        \n",
    "    def call(self, x):\n",
    "        return tf.math.add(x, self.w)\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'axes': self.axes})\n",
    "        return config\n",
    "\n",
    "\n",
    "class Minimum(PrunableLayer):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)       \n",
    "        \n",
    "    def call(self,inputs):\n",
    "        return tf.minimum(*inputs)\n",
    "\n",
    "\n",
    "class MinimumScalar(PrunableLayer):\n",
    "    def __init__(self, scalar, **kwargs):\n",
    "        super().__init__(**kwargs)       \n",
    "        self.scalar=scalar\n",
    "\n",
    "    def call(self,inputs):\n",
    "        return tf.minimum(inputs, self.scalar)\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'scalar': self.scalar})\n",
    "        return config\n",
    "\n",
    "\n",
    "class MaximumScalar(PrunableLayer):\n",
    "    def __init__(self, scalar, **kwargs):\n",
    "        super().__init__(**kwargs)       \n",
    "        self.scalar=scalar\n",
    "\n",
    "    def call(self,inputs):\n",
    "        return tf.maximum(inputs, self.scalar)\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'scalar': self.scalar})\n",
    "        return config\n",
    "\n",
    "\n",
    "class Cast(PrunableLayer):\n",
    "    def __init__(self, type = tf.int32, **kwargs):\n",
    "        super().__init__(**kwargs)  \n",
    "        self.type=type    \n",
    "\n",
    "    def call(self,inputs):\n",
    "        return tf.cast(inputs, self.type)\n",
    "\n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config.update({'type': self.type})\n",
    "        return config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(c) Define Transformer layers such as multi-head attention and layer normalization functionally"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def self_attention(query, key, value, n_heads, dim, mask=None, name='mha', block_name=None, out_dim=None):\n",
    "    \"\"\"Multi-head attention layer\"\"\"\n",
    "    depth = dim // n_heads\n",
    "    if out_dim is None: out_dim = query.shape[-1]\n",
    "    q = tf.keras.layers.Dense(units=dim, name=f'{name}/query')(query)\n",
    "    k = tf.keras.layers.Dense(units=dim, name=f'{name}/key')(key)\n",
    "    v = tf.keras.layers.Dense(units=dim, name=f'{name}/value')(value)\n",
    "\n",
    "    q = tf.keras.layers.Reshape((-1, n_heads, depth))(q)\n",
    "    q = tf.keras.layers.Permute((2, 1, 3))(q)\n",
    "    k = tf.keras.layers.Reshape((-1, n_heads, depth))(k)\n",
    "    k = tf.keras.layers.Permute((2, 1, 3))(k)\n",
    "    v = tf.keras.layers.Reshape((-1, n_heads, depth))(v)\n",
    "    v = tf.keras.layers.Permute((2, 1, 3))(v)\n",
    "    qk = ScalarMultiply(depth ** -0.5)(MatMul(transpose_b=True)([q, k]))\n",
    "\n",
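    "    if mask is not None:\n",
    "        # Keep scores where mask == 1 and set masked scores to -10:\n",
    "        # qk = qk * mask + (1 - mask) * -10\n",
    "        if isinstance(mask, tf.Tensor):\n",
    "            # A constant mask (e.g. the causal mask) is folded into scalar layers\n",
    "            qk = ScalarMultiply(mask)(qk)\n",
    "            mask = 1. - mask\n",
    "            mask = mask * -10\n",
    "            qk = ScalarAdd(mask)(qk)\n",
    "        else:\n",
    "            # A mask fed as a model input needs layers that operate on tensors\n",
    "            qk = Multiply()([qk, mask])\n",
    "            mask = ScalarSubtract(1.)(mask)  # 1 - mask\n",
    "            mask = ScalarMultiply(-10)(mask)\n",
    "            qk = Add(name=f'add/{name}')([qk, mask])\n",
    "\n",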
    "    attn_weights = tf.keras.layers.Softmax(axis=-1)(qk)\n",
    "    attn_out = MatMul()([attn_weights, v]) \n",
    "    attn_out = tf.keras.layers.Permute((2, 1, 3))(attn_out)\n",
    "    attn_out = tf.keras.layers.Reshape((-1, dim))(attn_out)\n",
    "    out = tf.keras.layers.Dense(out_dim, name=f'{name}/output_dense',  dtype=\"float32\")(attn_out)\n",
    "    \n",
    "    return out, attn_weights\n",
    "\n",
    "def AddPositionalEmbeddingForEncoderDecoder(inputs, seq_len, vocab_size, embed_dim, block_name, freeze):\n",
    "    x = AddTokenEmbedding(vocab_size, embed_dim, train = not freeze, name= ('token_embedding/' + block_name))(inputs)\n",
    "    x = AddPositionalEmbedding(seq_len, embed_dim, name= ('positional_embedding/' + block_name))(x)\n",
    "    return x\n",
    "   \n",
    "def enc_padding_mask(inputs):\n",
    "    computed_mask=tf.keras.layers.Reshape((1, 1, -1))(inputs)\n",
    "    return computed_mask   \n",
    "\n",
    "def causal_mask(inputs):\n",
    "    # Lower-triangular (seq_len, seq_len) matrix of ones: position i may attend to j <= i\n",
    "    seq_len=inputs.shape[1]\n",
    "    causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)\n",
    "    return causal_mask\n",
    "\n",
    "def dec_padding_mask(inputs, cau_mask):\n",
    "    padding_mask=  enc_padding_mask(inputs)\n",
    "    padding_mask = MinimumScalar(scalar=cau_mask)(padding_mask)\n",
    "    return padding_mask\n",
    "\n",
    "def layer_norm(x, axes=2, epsilon=0.001, name='layer_norm', trainable = True):\n",
    "    \"\"\"LayerNormalization\"\"\"\n",
    "    if isinstance(axes, int): axes = [axes]\n",
    "        \n",
    "    mean = Mean(axes=axes, dtype=x.dtype)(x)\n",
    "    ## This block can be replaced with a squared_difference layer ##\n",
    "    diff = Subtract()([x, StopGradient()(mean)])                  ##\n",
    "    sq_diff = Multiply()([diff, diff])                            ##\n",
    "    ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ##\n",
    "    variance = Mean(axes=axes,dtype=x.dtype ,name=f'{name}/variance')(sq_diff)\n",
    "    if not trainable:\n",
    "        inv = RSqrt()(variance)\n",
    "        x = Multiply()([diff, inv])\n",
    "    else:\n",
    "        # MaximumScalar prevents division by 0.\n",
    "        inv = RSqrt()(MaximumScalar(epsilon)(variance))\n",
    "        # This layer is removed for inference so it is named.\n",
    "        x = Subtract(name=f'{name}/grad_subtract')([x, mean]) \n",
    "        x = Multiply()([x, inv])\n",
    "\n",
    "    x = Scale(axes=axes)(x)\n",
    "    x = Centre(axes=axes)(x)\n",
    "    \n",
    "    return x\n",
    "\n",
    "def mlp(x, hidden_dim, out_dim=None):\n",
    "    \"\"\"Multi-layer perceptron block\"\"\"\n",
    "    if out_dim is None: out_dim = x.shape[-1]\n",
    "\n",
    "    x = tf.keras.layers.Dense(units=hidden_dim)(x)\n",
    "    x = Relu()(x)\n",
    "    x = tf.keras.layers.Dense(units=out_dim)(x)\n",
    "    return x"
   ]
  },
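  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`self_attention` masks scores with a multiply-and-add, `qk * mask + (1 - mask) * -10`, built from single-op layers. `layer_norm` likewise rebuilds layer normalization from single-op layers. The NumPy sketch below re-expresses both computations for one toy input so the arithmetic can be checked in isolation. It is an illustration, not part of the model.\n",
    "\n",
    "Two points stand out. First, a fill value of -10 makes masked attention weights small but not exactly zero. Second, `layer_norm` guards the variance with a maximum against `epsilon`, whereas Keras's `LayerNormalization` adds `epsilon` to the variance. For variances well above `epsilon` the two agree closely.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def softmax(x, axis=-1):\n",
    "    e = np.exp(x - x.max(axis=axis, keepdims=True))\n",
    "    return e / e.sum(axis=axis, keepdims=True)\n",
    "\n",
    "def masked_scores(qk, mask, fill=-10.0):\n",
    "    # Same arithmetic as self_attention: qk * mask + (1 - mask) * fill\n",
    "    return qk * mask + (1.0 - mask) * fill\n",
    "\n",
    "def decomposed_layer_norm(x, gamma, beta, epsilon=0.001):\n",
    "    # Same steps as layer_norm (trainable=True branch): Mean, Subtract,\n",
    "    # Multiply, Mean, MaximumScalar, RSqrt, Multiply, Scale, Centre\n",
    "    mean = x.mean(axis=-1, keepdims=True)\n",
    "    diff = x - mean\n",
    "    variance = (diff * diff).mean(axis=-1, keepdims=True)\n",
    "    inv = 1.0 / np.sqrt(np.maximum(variance, epsilon))\n",
    "    return diff * inv * gamma + beta\n",
    "\n",
    "# Attention scores for one query over three keys; the third key is masked out\n",
    "qk = np.array([[2.0, 1.0, 0.5]])\n",
    "mask = np.array([[1.0, 1.0, 0.0]])\n",
    "weights = softmax(masked_scores(qk, mask))\n",
    "print(weights)  # the masked key gets a small but nonzero weight\n",
    "\n",
    "# Layer normalization of one 4-dimensional vector with gamma = 1 and beta = 0\n",
    "x = np.array([[1.0, 2.0, 3.0, 4.0]])\n",
    "y = decomposed_layer_norm(x, gamma=np.ones(4), beta=np.zeros(4))\n",
    "print(y.mean(), y.std())  # approximately 0 and 1\n",
    "```"
   ]
  },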
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(d) Build the end-to-end model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "def get_translation_model(input_shape, batch_size=batch_size, seq_len=seq_len, vocab_size=vocab_size, embed_dim=embed_dim, num_heads=num_heads, freeze= False, trainable=True):\n",
    "    \n",
    "    aux_output=defaultdict(list)\n",
    "    ## Encoder\n",
    "    \n",
    "    # Input to encoder\n",
    "    enc_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name=\"encoder_inputs\")\n",
    "    encoder_inputs=Cast()(enc_inputs)\n",
    "    \n",
    "    x = AddPositionalEmbeddingForEncoderDecoder(encoder_inputs, seq_len, vocab_size, embed_dim, 'encoder', freeze)\n",
    "    encoder_padding_mask_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name=\"encoder_masks\")\n",
    "    encoder_padding_mask = enc_padding_mask(encoder_padding_mask_inputs)\n",
    "\n",
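    "    # Encoder Attention block\n",
    "    # dim = embed_dim * num_heads gives each head a depth of embed_dim,\n",
    "    # matching key_dim=embed_dim in the original Keras model\n",
    "    attention_output, attention_weights = self_attention(x, x, x, num_heads, embed_dim*num_heads, mask=encoder_padding_mask, name=(f'mha'), block_name=(f'encoder'))\n",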
    "    proj_input = tf.keras.layers.Add()([x, attention_output])\n",
    "    proj_input = layer_norm(proj_input, name=(f'layer_norm'), trainable=trainable)\n",
    "\n",
    "    # MLP block\n",
    "    proj_output = mlp(proj_input, latent_dim, embed_dim)\n",
    "    x = tf.keras.layers.Add()([proj_input, proj_output])\n",
    "    encoder_outputs = layer_norm(x, name=(f'layer_norm_1'), trainable=trainable)\n",
    "    \n",
    "    ## Decoder\n",
    "    \n",
    "    # Input to decoder\n",
    "    dec_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name=\"decoder_inputs\")\n",
    "    decoder_inputs=Cast()(dec_inputs)\n",
    "\n",
    "    x = AddPositionalEmbeddingForEncoderDecoder(decoder_inputs, seq_len, vocab_size, embed_dim, 'decoder', freeze)\n",
    "    decoder_causal_mask = causal_mask(decoder_inputs)\n",
    "    decoder_padding_mask_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name=\"decoder_masks\")\n",
    "    decoder_padding_mask = dec_padding_mask(decoder_padding_mask_inputs, decoder_causal_mask)\n",
    "    \n",
    "    \n",
    "    # Decoder Attention Block 1\n",
    "    attention_output_1, attention_weights_1 = self_attention(x, x, x, num_heads, embed_dim*num_heads, mask=decoder_causal_mask, name=(f'mha_1'), block_name=(f'decoder_1'))\n",
    "    x1 = tf.keras.layers.Add()([x, attention_output_1])\n",
    "    out_1 = layer_norm(x1,  name=(f'layer_norm_2'), trainable=trainable)\n",
    "    \n",
    "    # Decoder Attention Block 2\n",
    "    attention_output_2, attention_weights_2 = self_attention(out_1, encoder_outputs, encoder_outputs, num_heads, embed_dim*num_heads, mask=decoder_padding_mask, name=(f'mha_2'), block_name=(f'decoder_2'))\n",
    "    x2 = tf.keras.layers.Add()([out_1, attention_output_2])\n",
    "    out_2 = layer_norm(x2,  name=(f'layer_norm_3'), trainable=trainable)\n",
    "    \n",
    "    # MLP Block\n",
    "    proj_output = mlp(out_2, latent_dim, embed_dim)\n",
    "    x3 =  tf.keras.layers.Add()([out_2, proj_output])\n",
    "    x3 = layer_norm(x3, name=(f'layer_norm_4'), trainable=trainable)\n",
    "    \n",
    "\n",
    "    x3 = tf.keras.layers.Dropout(0.5)(x3)\n",
    "    x3 = tf.keras.layers.Dense(units=vocab_size, name=\"dense_last\", activation='softmax')(x3)\n",
    "\n",
    "    transformer = keras.Model(\n",
    "        [enc_inputs,encoder_padding_mask_inputs, dec_inputs, decoder_padding_mask_inputs], x3, name=\"transformer\"\n",
    "    )\n",
    "    \n",
    "    return transformer\n",
    "\n",
    "tf.keras.backend.clear_session()  # reset layer name counters\n",
    "\n",
    "transform = get_translation_model(input_shape = (seq_len,), batch_size = batch_size)"
   ]
  },
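  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike the original Keras model, the functional model takes its padding masks as explicit inputs, in the order `[encoder_inputs, encoder_masks, decoder_inputs, decoder_masks]`. Its `Input` layers use the default float32 dtype, and the `Cast` layer converts token ids to int32 inside the model.\n",
    "\n",
    "The sketch below shows one way such inputs could be prepared with NumPy, assuming token id 0 is padding. `make_model_inputs` is a hypothetical helper, and the notebook's data pipeline defined earlier may build these arrays differently. Because the model is built with a fixed `batch_size`, real inputs must have exactly `batch_size` rows and `seq_len` columns. The single-row example only illustrates the mask layout.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def make_model_inputs(encoder_tokens, decoder_tokens, pad_id=0):\n",
    "    # Hypothetical helper: returns [encoder_inputs, encoder_masks, decoder_inputs, decoder_masks]\n",
    "    # Masks are 1.0 for real tokens and 0.0 for padding (assumes padding id 0)\n",
    "    enc = np.asarray(encoder_tokens, dtype=np.float32)\n",
    "    dec = np.asarray(decoder_tokens, dtype=np.float32)\n",
    "    enc_mask = (enc != pad_id).astype(np.float32)\n",
    "    dec_mask = (dec != pad_id).astype(np.float32)\n",
    "    return [enc, enc_mask, dec, dec_mask]\n",
    "\n",
    "inputs = make_model_inputs([[4, 9, 2, 0, 0]], [[1, 6, 0, 0, 0]])\n",
    "print(inputs[1])  # [[1. 1. 1. 0. 0.]]\n",
    "```"
   ]
  },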
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(e) Train the FP32 model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform.summary()\n",
    "train(transform, model_type='fp32')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(f) Evaluate performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get BLEU score on test set for FP32 transformer model\n",
    "get_text_result(transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get accuracy on the test set for the functional FP32 transformer model\n",
    "evaluate(transform)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 9. Convert the FP32 model to an FP32 TFLite model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(a) Generate a non-optimized TFLite file (float32 operations) for the FP32 model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
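    "# Re-wrap the trained model in new inputs with a fixed batch size of 1 so that\n",
    "# the TFLite graph has static input shapes. The input order matches the model:\n",
    "# encoder tokens, encoder padding mask, decoder tokens, decoder padding mask\n",
    "i = tf.keras.Input(shape=(20,), batch_size=1)\n",
    "j = tf.keras.Input(shape=(20,), batch_size=1)\n",
    "k = tf.keras.Input(shape=(20,), batch_size=1)\n",
    "l = tf.keras.Input(shape=(20,), batch_size=1)\n",
    "net = tf.keras.Model(inputs=[i, j, k, l], outputs=transform.call([i, j, k, l]))\n",
    "\n",
    "MODEL_PATH = './encoder_decoder_fp32.tflite'\n",
    "\n",
    "# No optimisations are set, so the converted model keeps float32 operations\n",
    "converter = tf.lite.TFLiteConverter.from_keras_model(net)\n",
    "tflite_model = converter.convert()\n",
    "with open(MODEL_PATH, \"wb+\") as tflite_file:\n",
    "    tflite_file.write(tflite_model)"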
   ]
  },
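  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to confirm that re-wrapping the model in new `tf.keras.Input` layers fixed the batch size is to inspect the input tensors of the converted file with `tf.lite.Interpreter`. The helper below is a minimal sketch; its name, `describe_tflite_io`, is hypothetical and it is not one of this notebook's utilities.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "def describe_tflite_io(model_path):\n",
    "    # Load the flatbuffer and allocate tensors so that the details are populated\n",
    "    interpreter = tf.lite.Interpreter(model_path=model_path)\n",
    "    interpreter.allocate_tensors()\n",
    "    rows = []\n",
    "    for detail in interpreter.get_input_details():\n",
    "        shape = [int(d) for d in detail['shape']]\n",
    "        # 'quantization' is a (scale, zero_point) pair; (0.0, 0) means not quantised\n",
    "        rows.append((detail['name'], shape, detail['dtype'].__name__, detail['quantization']))\n",
    "    return rows\n",
    "```\n",
    "\n",
    "For the FP32 file written above, `describe_tflite_io(MODEL_PATH)` should list four inputs, each with shape `[1, 20]` and dtype `float32`."
   ]
  },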
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(b) Evaluate performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_tflite_accuracy(MODEL_PATH, input_type='float32', output_type='float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_text_result_tflite(MODEL_PATH, input_type='float32', output_type='float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Model size: \", get_gzipped_model_size(MODEL_PATH), ' KB')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10. Perform QAT on FP32 model with TFMOT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(a) To use the custom Keras layers defined above, a [`QuantizeConfig`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/quantization/keras/QuantizeConfig) must be passed for each of these layers.\n",
    "\n",
    "For Keras layers that TFMOT already supports, a default QuantizeConfig is assigned automatically. However, custom QuantizeConfig instances can also be created for these layers to give more control over how they are quantised. The cell below defines five configs, ranging from `NoOpQuantizeConfig`, which adds no quantizers at all, to `WeightQuantizeConfig`, which quantises a layer's custom weight `w` as well as its output."
   ]
  },
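  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All of the output quantizers configured below use `num_bits=8` and `symmetric=False`. During QAT such a quantizer acts as a fake-quantisation op: values are rounded onto 256 levels spanning a tracked range and immediately mapped back to float, so the network learns to tolerate the rounding error. The NumPy function below is a simplified sketch of that asymmetric (affine) scheme for illustration only; TensorFlow's own fake-quant kernels handle details such as range nudging themselves, and the name `fake_quantise` is hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def fake_quantise(x, range_min, range_max, num_bits=8):\n",
    "    # Widen the range to contain 0 so that zero maps exactly onto an integer level\n",
    "    range_min = min(range_min, 0.0)\n",
    "    range_max = max(range_max, 0.0)\n",
    "    if range_max == range_min:\n",
    "        return np.zeros_like(x, dtype=np.float32)\n",
    "    levels = 2 ** num_bits - 1  # 255 steps for 8 bits\n",
    "    scale = (range_max - range_min) / levels\n",
    "    zero_point = np.round(-range_min / scale)\n",
    "    # Quantise: scale, shift by the zero point and clamp to [0, levels]\n",
    "    q = np.clip(np.round(x / scale) + zero_point, 0, levels)\n",
    "    # Dequantise straight back to float, as a fake-quant op does during training\n",
    "    return ((q - zero_point) * scale).astype(np.float32)\n",
    "\n",
    "x = np.linspace(-1.0, 3.0, 9, dtype=np.float32)\n",
    "x_fq = fake_quantise(x, -1.0, 3.0)\n",
    "print(np.max(np.abs(x - x_fq)))  # at most half a step, (3.0 - (-1.0)) / 255 / 2\n",
    "```\n",
    "\n",
    "Because the zero point is an integer, an input of exactly 0 survives the round trip unchanged, which matters for padded positions."
   ]
  },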
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow_model_optimization.quantization.keras import QuantizeConfig, quantizers\n",
    "\n",
    "LastValueQuantizer = quantizers.LastValueQuantizer\n",
    "MovingAverageQuantizer = quantizers.MovingAverageQuantizer\n",
    "AllValuesQuantizer = quantizers.AllValuesQuantizer\n",
    "\n",
    "class NoOpQuantizeConfig(QuantizeConfig):\n",
    "    \"\"\"QuantizeConfig which does not quantize any part of the layer.\"\"\"\n",
    "\n",
    "    def get_weights_and_quantizers(self, layer):\n",
    "        return []\n",
    "\n",
    "    def get_activations_and_quantizers(self, layer):\n",
    "        return []\n",
    "\n",
    "    def set_quantize_weights(self, layer, quantize_weights):\n",
    "        pass\n",
    "\n",
    "    def set_quantize_activations(self, layer, quantize_activations):\n",
    "        pass\n",
    "\n",
    "    def get_output_quantizers(self, layer):\n",
    "        return []\n",
    "        \n",
    "    def get_config(self):\n",
    "        return {}\n",
    "\n",
    "\n",
    "class TFOpQuantizeConfig(QuantizeConfig):\n",
    "    \"\"\"QuantizeConfig which only quantizes the output of a layer.\"\"\"\n",
    "\n",
    "    def get_weights_and_quantizers(self, layer):\n",
    "        return []\n",
    "\n",
    "    def get_activations_and_quantizers(self, layer):\n",
    "        return []\n",
    "\n",
    "    def set_quantize_weights(self, layer, quantize_weights):\n",
    "        pass\n",
    "\n",
    "    def set_quantize_activations(self, layer, quantize_activations):\n",
    "        pass\n",
    "\n",
    "    def get_output_quantizers(self, layer):\n",
    "        return [MovingAverageQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n",
    "\n",
    "    def get_config(self):\n",
    "        return {}\n",
    "\n",
    "\n",
    "class MaskOpQuantizeConfig(QuantizeConfig):\n",
    "    \"\"\"QuantizeConfig which only quantizes the output of a layer and is meant for the input masks.\"\"\"\n",
    "\n",
    "    def get_weights_and_quantizers(self, layer):\n",
    "        return []\n",
    "\n",
    "    def get_activations_and_quantizers(self, layer):\n",
    "        return []\n",
    "\n",
    "    def set_quantize_weights(self, layer, quantize_weights):\n",
    "        pass\n",
    "\n",
    "    def set_quantize_activations(self, layer, quantize_activations):\n",
    "        pass\n",
    "\n",
    "    def get_output_quantizers(self, layer):\n",
    "        return [AllValuesQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n",
    "\n",
    "    def get_config(self):\n",
    "        return {}\n",
    "\n",
    "    \n",
    "class VarianceQuantizeConfig(QuantizeConfig):\n",
    "    \"\"\"QuantizeConfig for the variance calculation in the layer normalisation layer.\"\"\"\n",
    "\n",
    "    def get_weights_and_quantizers(self, layer):\n",
    "        return []\n",
    "\n",
    "    def get_activations_and_quantizers(self, layer):\n",
    "        return []\n",
    "\n",
    "    def set_quantize_weights(self, layer, quantize_weights):\n",
    "        pass\n",
    "\n",
    "    def set_quantize_activations(self, layer, quantize_activations):\n",
    "        pass\n",
    "\n",
    "    def get_output_quantizers(self, layer):\n",
    "        return [AllValuesQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n",
    "\n",
    "    def get_config(self):\n",
    "        return {}\n",
    "        \n",
    "\n",
    "class WeightQuantizeConfig(QuantizeConfig):\n",
    "    \"\"\"QuantizeConfig which quantizes the custom weights in the embedding layers (positional and token) and the layer normalisation layers.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.weight_quantizer = LastValueQuantizer(num_bits=8, per_axis=False,\n",
    "                                                   symmetric=True, narrow_range=True)\n",
    "        self.activation_quantizer = MovingAverageQuantizer(num_bits=8, per_axis=False,\n",
    "                                                           symmetric=False, narrow_range=False)\n",
    "\n",
    "    def get_weights_and_quantizers(self, layer):\n",
    "        return [(layer.w, self.weight_quantizer)]\n",
    "\n",
    "    def get_activations_and_quantizers(self, layer):\n",
    "        return []\n",
    "\n",
    "    def set_quantize_weights(self, layer, quantize_weights):\n",
    "        layer.w = quantize_weights[0]\n",
    "\n",
    "    def set_quantize_activations(self, layer, quantize_activations):\n",
    "        pass\n",
    "\n",
    "    def get_output_quantizers(self, layer):\n",
    "        return [self.activation_quantizer]\n",
    "\n",
    "    def get_config(self):\n",
    "        return {}"
   ]
  },
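  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The configs above differ mainly in which quantizer tracks the range of an output. As a rough mental model (a simplified sketch, not TFMOT's implementation): `LastValueQuantizer` uses the min/max of the most recent batch, `MovingAverageQuantizer` keeps an exponential moving average of the batch min/max, and `AllValuesQuantizer` keeps the overall min/max seen so far. The hypothetical `track_ranges` function below compares the three strategies on synthetic batches; the decay of 0.999 and initialising from the first batch are assumptions of this sketch.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def track_ranges(batches, ema_decay=0.999):\n",
    "    # Simplified range trackers; ema_decay=0.999 is an assumed value\n",
    "    last = moving = overall = None\n",
    "    for batch in batches:\n",
    "        b_min, b_max = float(np.min(batch)), float(np.max(batch))\n",
    "        last = (b_min, b_max)\n",
    "        if moving is None:\n",
    "            # Initialise both running ranges from the first batch\n",
    "            moving, overall = last, last\n",
    "            continue\n",
    "        moving = (ema_decay * moving[0] + (1 - ema_decay) * b_min,\n",
    "                  ema_decay * moving[1] + (1 - ema_decay) * b_max)\n",
    "        overall = (min(overall[0], b_min), max(overall[1], b_max))\n",
    "    return {'last_value': last, 'moving_average': moving, 'all_values': overall}\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Five narrow batches, one rare outlier batch, then one more narrow batch\n",
    "batches = [rng.uniform(-1.0, 1.0, 64) for _ in range(5)]\n",
    "batches.append(np.array([-10.0, 10.0]))\n",
    "batches.append(rng.uniform(-1.0, 1.0, 64))\n",
    "\n",
    "for name, (lo, hi) in track_ranges(batches).items():\n",
    "    print(f'{name:15s} min={lo:7.3f} max={hi:7.3f}')\n",
    "```\n",
    "\n",
    "With these batches, the all-values range keeps the outlier's full `[-10, 10]` span, the last-value range reflects only the final batch, and the moving-average range stays close to `[-1, 1]` because the single outlier batch is weighted by only 0.001."
   ]
  },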
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(b) Define wrapper function\n",
    "\n",
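    "Since the model contains custom layers that need custom QuantizeConfigs, the whole model cannot be wrapped with QAT wrappers in a single step.\n",
    "Instead, the function below clones the model with `tf.keras.models.clone_model` and wraps each layer listed in `layer_param_dict` individually, passing that layer's entry (e.g. its `quantize_config`) to the wrapper function:"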
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_wrapper(wrapper_function, layer_param_dict):\n",
    "    # Returns a clone_function that wraps only the layers named in layer_param_dict\n",
    "    def wrap_layer(layer):\n",
    "        if layer.name in layer_param_dict:\n",
    "            return wrapper_function(layer, **layer_param_dict[layer.name])\n",
    "        # Layers that are not listed are reused unchanged\n",
    "        return layer\n",
    "\n",
    "    return wrap_layer\n",
    "\n",
    "def layer_wrapper(model, wrapper_function, layer_param_dict):\n",
    "    # Clone the model, passing each layer through wrap_layer\n",
    "    return tf.keras.models.clone_model(model, clone_function=apply_wrapper(wrapper_function, layer_param_dict))"
   ]
  },
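  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same clone-and-annotate idea can be tried on a small model first. The sketch below follows the pattern shown in the TFMOT quantization-aware training comprehensive guide rather than anything specific to this notebook: a `clone_function` wraps only the `Dense` layers of a hypothetical toy model with `quantize_annotate_layer` (using the default QuantizeConfig), returns every other layer unchanged, and `quantize_apply` then inserts the quantization ops. All `toy_*` names are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_model_optimization as tfmot\n",
    "\n",
    "def annotate_dense_only(layer):\n",
    "    # Wrap Dense layers with the default QuantizeConfig; return others unchanged\n",
    "    if isinstance(layer, tf.keras.layers.Dense):\n",
    "        return tfmot.quantization.keras.quantize_annotate_layer(layer)\n",
    "    return layer\n",
    "\n",
    "# Hypothetical toy model, unrelated to the transformer above\n",
    "toy_inputs = tf.keras.Input(shape=(4,))\n",
    "toy_hidden = tf.keras.layers.Dense(8, activation='relu', name='toy_dense_1')(toy_inputs)\n",
    "toy_outputs = tf.keras.layers.Dense(2, name='toy_dense_2')(toy_hidden)\n",
    "toy_model = tf.keras.Model(toy_inputs, toy_outputs)\n",
    "\n",
    "toy_annotated = tf.keras.models.clone_model(toy_model, clone_function=annotate_dense_only)\n",
    "toy_qat = tfmot.quantization.keras.quantize_apply(toy_annotated)\n",
    "\n",
    "# Show which layers ended up wrapped by TFMOT\n",
    "for layer in toy_qat.layers:\n",
    "    print(layer.name, type(layer).__name__)\n",
    "```\n",
    "\n",
    "Printing the layer classes like this is a quick way to check which layers received quantization wrappers."
   ]
  },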
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(c) Assign QuantizeConfigs to custom layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_quantize_config(model):\n",
    "    layer_param_dict = {}  # stores {Layer_Name: QuantizeConfig} pairs\n",
    "    scope = {}  # stores all custom objects\n",
    "\n",
    "    for layer in model.layers:\n",
    "        if layer.name.startswith(('clip', 'minimum', 'minimum_scalar', 'maximum_scalar', 'cast', 'stop_gradient')):\n",
    "            # No quantizers are inserted for these layers\n",
    "            layer_param_dict[layer.name] = {'quantize_config': NoOpQuantizeConfig()}\n",
    "            scope[layer.__class__.__name__] = layer.__class__\n",
    "\n",
    "        elif 'grad_subtract' in layer.name or layer.name.startswith(('mat_mul', 'multiply', 'scalar_multiply', 'add',\n",
    "                                                                     'scalar_add', 'slice', 'mean', 'subtract',\n",
    "                                                                     'scalar_subtract', 'r_sqrt', 'relu')):\n",
    "            # TF-op layers: quantize only the output\n",
    "            layer_param_dict[layer.name] = {'quantize_config': TFOpQuantizeConfig()}\n",
    "            scope[layer.__class__.__name__] = layer.__class__\n",
    "\n",
    "        elif layer.name.startswith(('scale', 'centre', 'positional_embedding', 'token_embedding')):\n",
    "            # Layers with a custom weight `w`: quantize the weight and the output\n",
    "            layer_param_dict[layer.name] = {'quantize_config': WeightQuantizeConfig()}\n",
    "            scope[layer.__class__.__name__] = layer.__class__\n",
    "\n",
    "        elif layer.name.startswith(('encoder_masks', 'decoder_masks')):\n",
    "            # Quantize the encoder and decoder mask inputs so that they can be fed as INT8\n",
    "            layer_param_dict[layer.name] = {'quantize_config': MaskOpQuantizeConfig()}\n",
    "\n",
    "        elif 'variance' in layer.name:\n",
    "            layer_param_dict[layer.name] = {'quantize_config': VarianceQuantizeConfig()}\n",
    "            scope[layer.__class__.__name__] = layer.__class__\n",
    "\n",
    "    scope['NoOpQuantizeConfig'] = NoOpQuantizeConfig\n",
    "    scope['TFOpQuantizeConfig'] = TFOpQuantizeConfig\n",
    "    scope['WeightQuantizeConfig'] = WeightQuantizeConfig\n",
    "    scope['VarianceQuantizeConfig'] = VarianceQuantizeConfig\n",
    "    scope['MaskOpQuantizeConfig'] = MaskOpQuantizeConfig\n",
    "\n",
    "    return layer_param_dict, scope\n",
    "\n",
    "layer_param_dict, scope = get_quantize_config(transform)\n"
   ]
  },
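  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `get_quantize_config` dispatches purely on layer-name prefixes and substrings, the order of the `if`/`elif` branches matters and an unexpected name silently falls through. A pure-Python sketch of the same rules, useful for checking which config a given name would receive without building the model, might look like this; `RULES` and `classify_layer_name` are hypothetical helpers.\n",
    "\n",
    "```python\n",
    "# Mirrors the order of the if/elif branches in get_quantize_config above\n",
    "RULES = [\n",
    "    ('NoOpQuantizeConfig', lambda n: n.startswith(\n",
    "        ('clip', 'minimum', 'minimum_scalar', 'maximum_scalar', 'cast', 'stop_gradient'))),\n",
    "    ('TFOpQuantizeConfig', lambda n: 'grad_subtract' in n or n.startswith(\n",
    "        ('mat_mul', 'multiply', 'scalar_multiply', 'add', 'scalar_add', 'slice',\n",
    "         'mean', 'subtract', 'scalar_subtract', 'r_sqrt', 'relu'))),\n",
    "    ('WeightQuantizeConfig', lambda n: n.startswith(\n",
    "        ('scale', 'centre', 'positional_embedding', 'token_embedding'))),\n",
    "    ('MaskOpQuantizeConfig', lambda n: n.startswith(('encoder_masks', 'decoder_masks'))),\n",
    "    ('VarianceQuantizeConfig', lambda n: 'variance' in n),\n",
    "]\n",
    "\n",
    "def classify_layer_name(name):\n",
    "    # Return the first matching config name, or None if no rule applies\n",
    "    for config_name, matches in RULES:\n",
    "        if matches(name):\n",
    "            return config_name\n",
    "    return None\n",
    "\n",
    "# Example names; apart from dense_last, these are hypothetical\n",
    "for example_name in ['add_3', 'scalar_add_1', 'variance_2', 'cast_1', 'dense_last']:\n",
    "    print(example_name, '->', classify_layer_name(example_name))\n",
    "```\n",
    "\n",
    "A name such as `dense_last` matches no rule, so it is left to `remove_unwanted_layers` below. Note also that the `'minimum'` prefix already covers `'minimum_scalar'`, so that entry never changes the outcome."
   ]
  },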
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some layers, such as `cast`, `encoder_inputs` and `decoder_inputs`, must not be annotated with any QuantizeConfig. Annotating them would add a `quantize` node after the inputs in the TFLite graph, which would pass an int8 value to the `tfl.gather` operation. <br>Since `tfl.gather` accepts only int32 or int64 indices, an int8 value there results in an error ([see the TF Lite ops page](https://www.tensorflow.org/mlir/tfl_ops#tflgather_mlirtflgatherop))."
   ]
  },
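  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Beyond the dtype error, quantising token IDs would also lose information: an 8-bit tensor has only 256 levels, so distinct IDs from a larger vocabulary collapse onto the same value. The NumPy sketch below illustrates this with a hypothetical vocabulary of 15,000 IDs and a simple affine int8 mapping; it is an illustration of the arithmetic, not of what the converter would actually emit.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "example_vocab_size = 15000  # hypothetical vocabulary size\n",
    "token_ids = np.arange(example_vocab_size, dtype=np.int32)\n",
    "\n",
    "# Simple affine int8 mapping over the whole ID range [0, example_vocab_size - 1]\n",
    "id_scale = (example_vocab_size - 1) / 255.0\n",
    "id_zero_point = -128\n",
    "q = np.clip(np.round(token_ids / id_scale) + id_zero_point, -128, 127).astype(np.int8)\n",
    "recovered = np.round((q.astype(np.float32) - id_zero_point) * id_scale).astype(np.int32)\n",
    "\n",
    "print('distinct IDs before:', np.unique(token_ids).size)\n",
    "print('distinct IDs after: ', np.unique(recovered).size)  # at most 256\n",
    "print('IDs recovered exactly:', int(np.sum(recovered == token_ids)))\n",
    "```\n",
    "\n",
    "This is why the token inputs are kept as integers end to end, while only the float mask inputs are quantised."
   ]
  },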
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_unwanted_layers(model, layer_param_dict):\n",
    "    # Substrings identifying layers that must NOT be annotated; further layers that\n",
    "    # don't need quantization can be added alongside 'cast', 'encoder_inputs' and 'decoder_inputs'\n",
    "    excluded = ['cast', 'encoder_inputs', 'decoder_inputs']\n",
    "    layers_to_quantise = [x.name for x in model.layers if not any(y in x.name for y in excluded)]\n",
    "    # Drop any QuantizeConfig that was assigned to an excluded layer\n",
    "    layer_param_dict = {k: v for k, v in layer_param_dict.items() if k in layers_to_quantise}\n",
    "    # Every remaining layer without a custom config is annotated with\n",
    "    # quantize_config=None, i.e. TFMOT's default config for that layer type\n",
    "    for k in layers_to_quantise:\n",
    "        if k not in layer_param_dict:\n",
    "            layer_param_dict[k] = {'quantize_config': None}\n",
    "\n",
    "    return layer_param_dict\n",
    "\n",
    "layer_param_dict = remove_unwanted_layers(transform, layer_param_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(d) Load the necessary API classes/functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer\n",
    "quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model\n",
    "quantize_apply = tfmot.quantization.keras.quantize_apply\n",
    "quantize_scope = tfmot.quantization.keras.quantize_scope"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(e) Annotate individual layers\n",
    "\n",
    "When `quantize_apply` is called, TFMOT raises an error if a custom layer or QuantizeConfig class is missing from `scope`, so the call is made inside `quantize_scope(scope)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap each custom layer with the corresponding QuantizeConfig:\n",
    "\n",
    "qat_model = layer_wrapper(transform, quantize_annotate_layer, layer_param_dict)\n",
    "\n",
    "with quantize_scope(scope):\n",
    "    qat_model = quantize_apply(qat_model)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(f) Perform QAT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qat_model.summary()\n",
    "train(qat_model, model_type='qat', epochs=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(g) Evaluate Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_text_result(qat_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate(qat_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 11. Create INT8 tflite file for QAT FP32 model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we attempt to generate a TFLite file directly from the fine-tuned model above:\n",
    "\n",
    "- Its inputs will not have the required fixed batch size of 1.\n",
    "- It will contain operators that are unnecessary during inference. Specifically, the extra `Subtract` operators and the `MaximumScalar` operator in the layer normalisation blocks, which are only needed during training and fine-tuning, should be removed from the graph before the TFLite file is created.\n",
    "\n",
    "Therefore, the network is first rebuilt without the redundant operators (`trainable=False`) and the weights of the fine-tuned model are loaded into it; it is then wrapped in new inputs with a batch size of 1 when the TFLite file is created."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(a) Rebuild the model without the layers that are not required for inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.keras.backend.clear_session()  # reset layer name counters\n",
    "\n",
    "new_qat_model = get_translation_model(input_shape = (seq_len,), batch_size = batch_size, trainable = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(b) Annotate individual layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the QuantizeConfig and Scope which would be used to annotate the layers\n",
    "layer_param_dict, scope = get_quantize_config(new_qat_model)\n",
    "# Remove unwanted QuantizeConfigs\n",
    "layer_param_dict = remove_unwanted_layers(new_qat_model, layer_param_dict)    \n",
    "\n",
    "new_qat_model = layer_wrapper(new_qat_model, quantize_annotate_layer, layer_param_dict)\n",
    "\n",
    "with quantize_scope(scope):\n",
    "    new_qat_model = quantize_apply(new_qat_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(c) Load weights into the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_qat_model.load_weights('./eng_spa_transformer_qat_tutorial_qat_model.h5', by_name=True)"
   ]
  },
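  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`load_weights(..., by_name=True)` matches layers by name, which is why `tf.keras.backend.clear_session()` is called before the model is rebuilt: it resets Keras' automatic layer-name counters, so auto-generated names line up between the model that was trained and the one being rebuilt. The sketch below demonstrates the effect with hypothetical helpers `build_toy_model` and `dense_names`. Running it inside this notebook would reset the counters again, so it is best tried in a fresh session.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "def build_toy_model():\n",
    "    # Layers are left unnamed so that Keras assigns names from its global counters\n",
    "    inputs = tf.keras.Input(shape=(4,))\n",
    "    hidden = tf.keras.layers.Dense(8)(inputs)\n",
    "    outputs = tf.keras.layers.Dense(2)(hidden)\n",
    "    return tf.keras.Model(inputs, outputs)\n",
    "\n",
    "def dense_names(model):\n",
    "    return [layer.name for layer in model.layers if isinstance(layer, tf.keras.layers.Dense)]\n",
    "\n",
    "tf.keras.backend.clear_session()\n",
    "first = dense_names(build_toy_model())\n",
    "second = dense_names(build_toy_model())  # counters keep increasing\n",
    "tf.keras.backend.clear_session()\n",
    "after_reset = dense_names(build_toy_model())  # counters start again\n",
    "print(first, second, after_reset)\n",
    "```\n",
    "\n",
    "If the names differ, `by_name=True` silently skips the unmatched layers, so an evaluation like the sanity check below is a worthwhile guard."
   ]
  },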
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check to see if weights are loaded correctly\n",
    "evaluate(new_qat_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(d) Create tflite file (int8 ops)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token inputs are int32 because they feed tfl.gather; the mask inputs are float\n",
    "i = tf.keras.Input(shape=(20,), batch_size=1, dtype=tf.int32)\n",
    "j = tf.keras.Input(shape=(20,), batch_size=1)\n",
    "k = tf.keras.Input(shape=(20,), batch_size=1, dtype=tf.int32)\n",
    "l = tf.keras.Input(shape=(20,), batch_size=1)\n",
    "\n",
    "# Re-wrapping the model in these inputs ensures that the batch size\n",
    "# of the inputs in the TFLite graph is 1\n",
    "net = tf.keras.Model(inputs=[i, j, k, l], outputs=new_qat_model.call([i, j, k, l]))\n",
    "\n",
    "MODEL_PATH = './encoder_decoder_qat.tflite'\n",
    "\n",
    "converter = tf.lite.TFLiteConverter.from_keras_model(net)\n",
    "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
    "\n",
    "# The following two lines ensure that the mask inputs\n",
    "# and the output are int8\n",
    "converter.inference_input_type = tf.int8\n",
    "converter.inference_output_type = tf.int8\n",
    "\n",
    "# Experimental (private) converter option: True keeps BatchMatMul folded,\n",
    "# False lets the converter unfold it\n",
    "converter._experimental_disable_batchmatmul_unfold = True\n",
    "\n",
    "tflite_model = converter.convert()\n",
    "with open(MODEL_PATH, \"wb+\") as tflite_file:\n",
    "    tflite_file.write(tflite_model)"
   ]
  },
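  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `inference_input_type` and `inference_output_type` set to `tf.int8`, float mask values have to be quantised with each input tensor's `(scale, zero_point)` before they are passed to the interpreter, and int8 outputs dequantised afterwards, while the int32 token inputs are passed through unchanged. The hypothetical helper `run_tflite` below sketches this with `tf.lite.Interpreter`. It assumes the arrays are supplied in the order returned by `get_input_details()`, which is not guaranteed to match the Keras input order, so check each detail's `name` before relying on it.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "def run_tflite(model_path, inputs):\n",
    "    # inputs: list of NumPy arrays, assumed to follow get_input_details() order\n",
    "    interpreter = tf.lite.Interpreter(model_path=model_path)\n",
    "    interpreter.allocate_tensors()\n",
    "    for detail, value in zip(interpreter.get_input_details(), inputs):\n",
    "        if detail['dtype'] == np.int8:\n",
    "            # Quantise float values: q = round(x / scale) + zero_point\n",
    "            scale, zero_point = detail['quantization']\n",
    "            value = np.clip(np.round(value / scale) + zero_point, -128, 127)\n",
    "        interpreter.set_tensor(detail['index'], np.asarray(value).astype(detail['dtype']))\n",
    "    interpreter.invoke()\n",
    "    out = interpreter.get_output_details()[0]\n",
    "    result = interpreter.get_tensor(out['index'])\n",
    "    if out['dtype'] == np.int8:\n",
    "        # Dequantise back to float: x = (q - zero_point) * scale\n",
    "        scale, zero_point = out['quantization']\n",
    "        result = (result.astype(np.float32) - zero_point) * scale\n",
    "    return result\n",
    "```\n",
    "\n",
    "For example, `run_tflite(MODEL_PATH, [a, b, c, d])` with each array shaped `(1, 20)`, where `a` to `d` are placeholders for one prepared sample in the order reported by `get_input_details()`."
   ]
  },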
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(e) Evaluate Performance\n",
    "\n",
    "NOTE: These steps are slow to execute, so by default evaluation is performed on only 200 samples (this number can be changed by the user)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_tflite_accuracy(MODEL_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_text_result_tflite(MODEL_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Model size: \", get_gzipped_model_size(MODEL_PATH), ' KB')"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "cb8a88e82453314166ba9bb471422eb5142b42682e25f5c554fbcb3447334d71"
  },
  "kernelspec": {
   "display_name": "Python 3.6.9 64-bit ('venv': venv)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
